Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replacement: update RRIP #153

Merged
merged 2 commits into from
May 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/scala/coupledL2/Common.scala
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ class ReplacerInfo(implicit p: Parameters) extends L2Bundle {
val channel = UInt(3.W)
val opcode = UInt(3.W)
val reqSource = UInt(MemReqSource.reqSourceBits.W)
val refill_prefetch = Bool()
}

trait HasChannelBits { this: Bundle =>
68 changes: 56 additions & 12 deletions src/main/scala/coupledL2/Directory.scala
Original file line number Diff line number Diff line change
@@ -263,19 +263,47 @@ class Directory(implicit p: Parameters) extends L2Module {
io.replResp.bits.retry := refillRetry

/* ====== Update ====== */
// update replacer only when A hit or refill, at stage 3
val updateHit = reqValid_s3 && hit_s3 && req_s3.replacerInfo.channel(0) &&
// PLRU: update replacer only when A hit or refill, at stage 3
// RRIP: update replacer when A/C hit or refill
val updateHit = if(cacheParams.replacement == "drrip" || cacheParams.replacement == "srrip"){
reqValid_s3 && hit_s3 &&
((req_s3.replacerInfo.channel(0) && (req_s3.replacerInfo.opcode === AcquirePerm || req_s3.replacerInfo.opcode === AcquireBlock || req_s3.replacerInfo.opcode === Hint)) ||
(req_s3.replacerInfo.channel(2) && (req_s3.replacerInfo.opcode === Release || req_s3.replacerInfo.opcode === ReleaseData)))
} else {
reqValid_s3 && hit_s3 && req_s3.replacerInfo.channel(0) &&
(req_s3.replacerInfo.opcode === AcquirePerm || req_s3.replacerInfo.opcode === AcquireBlock)
}
val updateRefill = refillReqValid_s3 && !refillRetry
// update replacer when A/C hit or refill
replacerWen := updateHit || updateRefill

// !!![TODO]!!! check this @CLS
// hit-Promotion, miss-Insertion for RRIP, so refill should hit = false.B
val touch_way_s3 = Mux(refillReqValid_s3, replaceWay, way_s3)
val rrip_hit_s3 = Mux(refillReqValid_s3, false.B, hit_s3)
// origin-bit marks whether the data_block is reused
val origin_bit_opt = if(random_repl) None else
Some(Module(new SRAMTemplate(Bool(), sets, ways, singlePort = true)))
val origin_bits_r = origin_bit_opt.get.io.r(io.read.fire(), io.read.bits.set).resp.data
val origin_bits_hold = Wire(Vec(ways, Bool()))
origin_bits_hold := HoldUnless(origin_bits_r, RegNext(io.read.fire(), false.B))
origin_bit_opt.get.io.w(
replacerWen,
rrip_hit_s3,
req_s3.set,
UIntToOH(touch_way_s3)
)

if(cacheParams.replacement == "srrip"){
val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3)
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
val req_type = WireInit(0.U(4.W))
req_type := Cat(origin_bits_hold(touch_way_s3),
req_s3.replacerInfo.channel(2),
(req_s3.replacerInfo.channel(0) && req_s3.replacerInfo.opcode === Hint) || (req_s3.replacerInfo.channel(2) && metaAll_s3(touch_way_s3).prefetch.getOrElse(false.B)) || req_s3.replacerInfo.refill_prefetch,
req_s3.refill
)

val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3, req_type)
val repl_init = Wire(Vec(ways, UInt(2.W)))
repl_init.foreach(_ := 2.U(2.W))
replacer_sram_opt.get.io.w(
@@ -284,24 +312,40 @@ class Directory(implicit p: Parameters) extends L2Module {
Mux(resetFinish, set_s3, resetIdx),
1.U
)

} else if(cacheParams.replacement == "drrip"){
//Set Dueling
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
val req_type = WireInit(0.U(4.W))
req_type := Cat(origin_bits_hold(touch_way_s3),
req_s3.replacerInfo.channel(2),
(req_s3.replacerInfo.channel(0) && req_s3.replacerInfo.opcode === Hint) || (req_s3.replacerInfo.channel(2) && metaAll_s3(touch_way_s3).prefetch.getOrElse(false.B)) || req_s3.replacerInfo.refill_prefetch,
req_s3.refill
)

// Set Dueling
val PSEL = RegInit(512.U(10.W)) //32-monitor sets, 10-bits psel
// track monitor sets' hit rate for each policy: srrip-0,128...3968;brrip-64,192...4032
when(refillReqValid_s3 && (set_s3(6,0)===0.U) && !rrip_hit_s3){ //SDMs_srrip miss
// track monitor sets' hit rate for each policy
// basic SDMs complement-selection policy: srrip--set_idx[group-:]==set_idx[group_offset-:]; brrip--set_idx[group-:]==!set_idx[group_offset-:]
val setBits = log2Ceil(sets)
val half_setBits = setBits >> 1
val match_a = set_s3(setBits-1,setBits-half_setBits-1)===set_s3(setBits-half_setBits-1,0) // 512 sets [8:4][4:0]
val match_b = set_s3(setBits-1,setBits-half_setBits-1)===(~set_s3(setBits-half_setBits-1,0))
when(refillReqValid_s3 && match_a && !rrip_hit_s3 && (PSEL=/=1023.U)){ //SDMs_srrip miss
PSEL := PSEL + 1.U
} .elsewhen(refillReqValid_s3 && (set_s3(6,0)===64.U) && !rrip_hit_s3){ //SDMs_brrip miss
} .elsewhen(refillReqValid_s3 && match_b && !rrip_hit_s3 && (PSEL=/=0.U)){ //SDMs_brrip miss
PSEL := PSEL - 1.U
}
// decide use which policy by policy selection counter, for insertion
/* if set -> SDMs: use fix policy
else if PSEL(MSB)==0: use srrip
else if PSEL(MSB)==1: use brrip */
val repl_type = WireInit(false.B)
repl_type := Mux(set_s3(6,0)===0.U, false.B,
Mux(set_s3(6,0)===64.U, true.B,
Mux(PSEL(9)===0.U, false.B, true.B))) // false.B - srrip, true.B - brrip
val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3, repl_type)
repl_type := Mux(match_a, false.B,
Mux(match_b, true.B,
Mux(PSEL(9)===0.U, false.B, true.B))) // false.B - srrip, true.B - brrip

val next_state_s3 = repl.get_next_state(repl_state_s3, touch_way_s3, rrip_hit_s3, repl_type, req_type)

val repl_init = Wire(Vec(ways, UInt(2.W)))
repl_init.foreach(_ := 2.U(2.W))
2 changes: 1 addition & 1 deletion src/main/scala/coupledL2/L2Param.scala
Original file line number Diff line number Diff line change
@@ -64,7 +64,7 @@ case class L2Param
pageBytes: Int = 4096,
channelBytes: TLChannelBeatBytes = TLChannelBeatBytes(32),
clientCaches: Seq[L1Param] = Nil,
replacement: String = "plru",
replacement: String = "drrip",
mshrs: Int = 16,
releaseData: Int = 3,
/* 0 for dirty alone
1 change: 1 addition & 0 deletions src/main/scala/coupledL2/RequestArb.scala
Original file line number Diff line number Diff line change
@@ -146,6 +146,7 @@ class RequestArb(implicit p: Parameters) extends L2Module {
io.dirRead_s1.bits.replacerInfo.opcode := task_s1.bits.opcode
io.dirRead_s1.bits.replacerInfo.channel := task_s1.bits.channel
io.dirRead_s1.bits.replacerInfo.reqSource := task_s1.bits.reqSource
io.dirRead_s1.bits.replacerInfo.refill_prefetch := s1_needs_replRead && (mshr_task_s1.bits.opcode === HintAck && mshr_task_s1.bits.dsWen)
io.dirRead_s1.bits.refill := s1_needs_replRead
io.dirRead_s1.bits.mshrId := task_s1.bits.mshrId

43 changes: 31 additions & 12 deletions src/main/scala/coupledL2/utils/Replacer.scala
Original file line number Diff line number Diff line change
@@ -38,8 +38,8 @@ abstract class ReplacementPolicy {
def get_next_state(state: UInt, touch_ways: Seq[Valid[UInt]]): UInt = {
touch_ways.foldLeft(state)((prev, touch_way) => Mux(touch_way.valid, get_next_state(prev, touch_way.bits), prev))
}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool): UInt = {0.U}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool): UInt = {0.U}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool, req_type: UInt): UInt = {0.U}
def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool, req_type: UInt): UInt = {0.U}

def get_replace_way(state: UInt): UInt
}
@@ -310,7 +310,6 @@ class SetAssocReplacer(n_sets: Int, n_ways: Int, policy: String) extends SetAsso
def way(set: UInt) = logic.get_replace_way(state_vec(set))
}


// 2-bit static Re-Reference Interval Prediction
class StaticRRIP(n_ways: Int) extends ReplacementPolicy {
def nBits = 2 * n_ways
@@ -323,17 +322,24 @@ class StaticRRIP(n_ways: Int) extends ReplacementPolicy {
def access(touch_ways: Seq[Valid[UInt]]) = {}
def get_next_state(state: UInt, touch_way: UInt) = 0.U //DontCare

override def get_next_state(state: UInt, touch_way: UInt, hit: Bool): UInt = {
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, req_type: UInt): UInt = {
val State = Wire(Vec(n_ways, UInt(2.W)))
val nextState = Wire(Vec(n_ways, UInt(2.W)))
State.zipWithIndex.map { case (e, i) =>
e := state(2*i+1,2*i)
}
// hit-Promotion, miss-Insertion & Aging
val increcement = 3.U(2.W) - State(touch_way)
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
// rrpv: non-pref_hit/non-pref_refill(miss)/non-pref_release_reuse = 0;
// pref_hit do nothing; pref_refill = 1; non-pref_release_firstuse/pref_release = 2;
nextState.zipWithIndex.map { case (e, i) =>
e := Mux(i.U === touch_way,
Mux(hit, 0.U(2.W), 2.U(2.W)),
Mux((req_type(2,0) === 0.U && hit) || req_type(2,0) === 1.U || req_type === 12.U, 0.U,
Mux(req_type(2,0) === 3.U, 1.U,
Mux(req_type === 4.U || req_type(2,0) === 6.U, 2.U, State(i)))),
//Mux(hit, 0.U(2.W), 2.U(2.W)),
Mux(hit, State(i), State(i)+increcement)
)
}
@@ -375,7 +381,7 @@ class BRRIP(n_ways: Int) extends ReplacementPolicy {
def access(touch_ways: Seq[Valid[UInt]]) = {}
def get_next_state(state: UInt, touch_way: UInt) = 0.U //DontCare

override def get_next_state(state: UInt, touch_way: UInt, hit: Bool): UInt = {
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, req_type: UInt): UInt = {
val State = Wire(Vec(n_ways, UInt(2.W)))
val nextState = Wire(Vec(n_ways, UInt(2.W)))
State.zipWithIndex.map { case (e, i) =>
@@ -384,13 +390,26 @@ class BRRIP(n_ways: Int) extends ReplacementPolicy {

// hit-Promotion, miss-Insertion & Aging
val increcement = 3.U(2.W) - State(touch_way)
val random = (rand.nextInt(32)).U
// req_type[3]: 0-firstuse, 1-reuse; req_type[2]: 0-acquire, 1-release;
// req_type[1]: 0-non-prefetch, 1-prefetch; req_type[0]: 0-not-refill, 1-refill
// rrpv: non-pref_hit/non-pref_refill(miss)/non-pref_release_reuse = 0;
// pref_hit do nothing; pref_refill = 1; non-pref_release_firstuse/pref_release = 3;
nextState.zipWithIndex.map { case (e, i) =>
e := Mux(i.U === touch_way,
Mux(hit, 0.U(2.W), Mux(random === 0.U, 2.U(2.W), 3.U(2.W))), //TODO: touch_way=3 for most insertions, touch_rrpv=2 with certain probability
e := Mux(i.U === touch_way,
Mux((req_type(2,0) === 0.U && hit) || req_type(2,0) === 1.U || req_type === 12.U, 0.U,
Mux(req_type(2,0) === 3.U, 1.U,
Mux(req_type === 4.U || req_type(2,0) === 6.U, 3.U, State(i)))),
//Mux(hit, 0.U(2.W), 3.U(2.W)),
Mux(hit, State(i), State(i)+increcement)
)
}
/* val random = (rand.nextInt(32)).U
nextState.zipWithIndex.map { case (e, i) =>
e := Mux(i.U === touch_way,
Mux(hit, 0.U(2.W), Mux(random === 0.U, 2.U(2.W), 3.U(2.W))), //touch_way=3 for most insertions, touch_rrpv=2 with certain probability
Mux(hit, State(i), State(i)+increcement)
)
} */
Cat(nextState.map(x=>x).reverse)
}

@@ -432,8 +451,8 @@ class DRRIP(n_ways: Int) extends ReplacementPolicy {
def hit = {}

def get_next_state(state: UInt, touch_way: UInt) = 0.U //DontCare
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool): UInt = {
Mux(chosen_type, repl_BRRIP.get_next_state(state, touch_way, hit), repl_SRRIP.get_next_state(state, touch_way, hit))
override def get_next_state(state: UInt, touch_way: UInt, hit: Bool, chosen_type: Bool, req_type: UInt): UInt = {
Mux(chosen_type, repl_BRRIP.get_next_state(state, touch_way, hit, req_type), repl_SRRIP.get_next_state(state, touch_way, hit, req_type))
}
def get_replace_way(state: UInt): UInt = {
val RRPVVec = Wire(Vec(n_ways, UInt(2.W)))
@@ -452,4 +471,4 @@ class DRRIP(n_ways: Int) extends ReplacementPolicy {
PriorityEncoder(lrrWayVec)
}

}
}