Skip to content

Commit

Permalink
Merge pull request OpenXiangShan#1761 from OpenXiangShan/fix-wb-priority
Browse files Browse the repository at this point in the history
Timing optimizations for Ctrl and EXU
  • Loading branch information
poemonsense authored Sep 2, 2022
2 parents ad87977 + b0b91ec commit 350b5a9
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 81 deletions.
37 changes: 10 additions & 27 deletions src/main/scala/utils/BitUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,7 @@ object OnesMoreThan {
}

abstract class SelectOne {
protected val balance2 = RegInit(false.B)
balance2 := !balance2

// need_balance: for balanced selections only (DO NOT use this if you don't know what it is)
def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool])
def getBalance2: Bool = balance2
def getNthOH(n: Int): (Bool, Vec[Bool])
}

class NaiveSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
Expand All @@ -271,7 +266,7 @@ class NaiveSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
}
}

def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
def getNthOH(n: Int): (Bool, Vec[Bool]) = {
require(n > 0, s"$n should be positive to select the n-th one")
require(n <= n_sel, s"$n should not be larger than $n_sel")
// bits(i) is true.B and bits(i - 1, 0) has n - 1
Expand All @@ -290,26 +285,14 @@ class CircSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
val sel_backward = new NaiveSelectOne(bits.reverse, n_sel / 2)
val moreThan = Seq(1, 2).map(i => OnesMoreThan(bits, i))

def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
require(!need_balance || max_sel == 2, s"does not support load balance between $max_sel selections")
val selValid = if (!need_balance) {
OnesMoreThan(bits, n)
} else {
if (n == 1) {
// When balance2 bit is set, we prefer the second selection port.
Mux(balance2, moreThan.last, moreThan.head)
}
else {
require(n == 2)
Mux(balance2, moreThan.head, moreThan.last)
}
}
def getNthOH(n: Int): (Bool, Vec[Bool]) = {
val selValid = OnesMoreThan(bits, n)
val sel_index = (n + 1) / 2
if (n % 2 == 1) {
(selValid, sel_forward.getNthOH(sel_index, need_balance)._2)
(selValid, sel_forward.getNthOH(sel_index)._2)
}
else {
(selValid, VecInit(sel_backward.getNthOH(sel_index, need_balance)._2.reverse))
(selValid, VecInit(sel_backward.getNthOH(sel_index)._2.reverse))
}
}
}
Expand All @@ -325,15 +308,15 @@ class OddEvenSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
val n_odd = n_bits / 2
val sel_odd = new CircSelectOne((0 until n_odd).map(i => bits(2 * i + 1)), (n_sel + 1) / 2)

def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
def getNthOH(n: Int): (Bool, Vec[Bool]) = {
val sel_index = (n + 1) / 2
if (n % 2 == 1) {
val selected = sel_even.getNthOH(sel_index, need_balance)
val selected = sel_even.getNthOH(sel_index)
val sel = VecInit((0 until n_bits).map(i => if (i % 2 == 0) selected._2(i / 2) else false.B))
(selected._1, sel)
}
else {
val selected = sel_odd.getNthOH(sel_index, need_balance)
val selected = sel_odd.getNthOH(sel_index)
val sel = VecInit((0 until n_bits).map(i => if (i % 2 == 1) selected._2(i / 2) else false.B))
(selected._1, sel)
}
Expand All @@ -347,7 +330,7 @@ class CenterSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
def centerReverse(data: Seq[Bool]): Seq[Bool] = data.take(half_index).reverse ++ data.drop(half_index).reverse
val select = new CircSelectOne(centerReverse(bits), max_sel)

def getNthOH(n: Int, need_balance: Boolean): (Bool, Vec[Bool]) = {
def getNthOH(n: Int): (Bool, Vec[Bool]) = {
val selected = select.getNthOH(n)
(selected._1, VecInit(centerReverse(selected._2)))
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/xiangshan/backend/exu/Exu.scala
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ abstract class Exu(cfg: ExuConfig)(implicit p: Parameters) extends XSModule {
val fuInReady = config.fuConfigs.zip(fuIn).zip(functionUnits.zip(fuSel)).map { case ((fuCfg, in), (fu, sel)) =>
fu.io.redirectIn := io.redirect

if (fuCfg.hasInputBuffer) {
val buffer = Module(new InputBuffer(8))
if (fuCfg.hasInputBuffer._1) {
val buffer = Module(new InputBuffer(fuCfg.hasInputBuffer._2, fuCfg.hasInputBuffer._3))
buffer.io.redirect <> io.redirect
buffer.io.in.valid := in.valid && sel
buffer.io.in.bits.uop := in.bits.uop
Expand Down
18 changes: 1 addition & 17 deletions src/main/scala/xiangshan/backend/exu/MulDivExeUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package xiangshan.backend.exu
import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import xiangshan._
import xiangshan.backend.fu._

class MulDivExeUnit(implicit p: Parameters) extends ExeUnit(MulDivExeUnitCfg) {
Expand Down Expand Up @@ -68,21 +68,6 @@ class MulDivExeUnit(implicit p: Parameters) extends ExeUnit(MulDivExeUnitCfg) {
mul.ctrl.isHi := isH
mul.ctrl.sign := DontCare

val isDivSign = MDUOpType.isDivSign(func)
val divInputFunc = (x: UInt) => Mux(
isW,
Mux(isDivSign,
SignExt(x(31, 0), XLEN),
ZeroExt(x(31, 0), XLEN)
),
x
)
div.io.in.bits.src(0) := divInputFunc(src1)
div.io.in.bits.src(1) := divInputFunc(src2)
div.ctrl.isHi := isH
div.ctrl.isW := isW
div.ctrl.sign := isDivSign

XSDebug(io.fromInt.valid, "In(%d %d) Out(%d %d) Redirect:(%d %d)\n",
io.fromInt.valid, io.fromInt.ready,
io.out.valid, io.out.ready,
Expand All @@ -94,4 +79,3 @@ class MulDivExeUnit(implicit p: Parameters) extends ExeUnit(MulDivExeUnitCfg) {
io.out.valid, io.out.ready, io.out.bits.data, io.out.bits.uop.cf.pc
)
}

2 changes: 1 addition & 1 deletion src/main/scala/xiangshan/backend/fu/FunctionUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ case class FuConfig
latency: HasFuLatency = CertainLatency(0),
fastUopOut: Boolean = false,
fastImplemented: Boolean = false,
hasInputBuffer: Boolean = false,
hasInputBuffer: (Boolean, Int, Boolean) = (false, 0, false),
exceptionOut: Seq[Int] = Seq(),
hasLoadError: Boolean = false,
flushPipe: Boolean = false,
Expand Down
32 changes: 25 additions & 7 deletions src/main/scala/xiangshan/backend/fu/InputBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import utils._
import xiangshan._
import xiangshan.backend.issue.AgeDetector

class InputBuffer(numEntries: Int)(implicit p: Parameters) extends XSModule {
class InputBuffer(numEntries: Int, enableBypass: Boolean)(implicit p: Parameters) extends XSModule {
val io = IO(new Bundle() {
val redirect = Flipped(ValidIO(new Redirect))

Expand All @@ -42,8 +42,12 @@ class InputBuffer(numEntries: Int)(implicit p: Parameters) extends XSModule {
io.in.ready := hasEmpty
val enqVec = selectEnq._2

// bypass
val tryBypass = WireInit(false.B)
val doBypass = WireInit(false.B)

// enqueue
val doEnqueue = io.in.fire && !io.in.bits.uop.robIdx.needFlush(io.redirect)
val doEnqueue = io.in.fire && !doBypass && !io.in.bits.uop.robIdx.needFlush(io.redirect)
when (doEnqueue) {
for (i <- 0 until numEntries) {
when (enqVec(i)) {
Expand All @@ -57,11 +61,13 @@ class InputBuffer(numEntries: Int)(implicit p: Parameters) extends XSModule {
val age = Module(new AgeDetector(numEntries, 1))
age.io.enq(0) := Mux(doEnqueue, enqVec.asUInt, 0.U)

val isEmpty = RegInit(false.B)
isEmpty := !emptyVecNext.asUInt.andR
io.out.valid := isEmpty
val notEmpty = RegInit(false.B)
notEmpty := !emptyVecNext.asUInt.andR
io.out.valid := notEmpty || tryBypass
io.out.bits := Mux1H(age.io.out, data)
when (io.out.fire) {

val doDequeue = io.out.fire && !doBypass
when (doDequeue) {
for (i <- 0 until numEntries) {
when (age.io.out(i)) {
emptyVecNext(i) := true.B
Expand All @@ -70,6 +76,18 @@ class InputBuffer(numEntries: Int)(implicit p: Parameters) extends XSModule {
}
}

// assign bypass signals
if (enableBypass) {
val isEmpty = RegInit(false.B)
isEmpty := emptyVecNext.asUInt.andR

tryBypass := io.in.valid
when (isEmpty) {
io.out.bits := io.in.bits
}
doBypass := io.in.valid && io.out.ready && isEmpty
}

// flush
val flushVec = data.map(_.uop.robIdx).zip(emptyVec).map{ case (r, e) => !e && r.needFlush(io.redirect) }
for (i <- 0 until numEntries) {
Expand All @@ -79,7 +97,7 @@ class InputBuffer(numEntries: Int)(implicit p: Parameters) extends XSModule {
}

val flushDeq = VecInit(flushVec).asUInt
age.io.deq := Mux(io.out.fire, age.io.out, 0.U) | flushDeq
age.io.deq := Mux(doDequeue, age.io.out, 0.U) | flushDeq

val numValid = PopCount(emptyVec.map(e => !e))
XSPerfHistogram("num_valid", numValid, true.B, 0, numEntries, 1)
Expand Down
32 changes: 31 additions & 1 deletion src/main/scala/xiangshan/backend/fu/SRT16Divider.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ package xiangshan.backend.fu
import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.util._
import utils.SignExt
import utils._
import xiangshan._
import xiangshan.backend.fu.util.CSA3_2

class SRT16DividerDataModule(len: Int) extends Module {
Expand Down Expand Up @@ -465,3 +466,32 @@ class SRT16Divider(len: Int)(implicit p: Parameters) extends AbstractDivider(len
io.out.bits.data := divDataModule.io.out_data
io.out.bits.uop := uopReg
}

/**
 * Function unit that wraps an [[SRT16Divider]] and performs the operand
 * pre-processing and control decoding for it.
 *
 * The wrapper forwards its own `io` to the inner divider, then overrides the
 * two source operands and the divider control signals based on the
 * `fuOpType` of the incoming uop.
 *
 * @param len data width passed to both `FunctionUnit` and the inner `SRT16Divider`
 */
class DividerWrapper(len: Int)(implicit p: Parameters) extends FunctionUnit(len) {
// The actual divider implementation.
val div = Module(new SRT16Divider(len))

// Bulk-connect the whole interface first. The explicit `:=` connections to
// div.io.in.bits.src(...) further down come later and therefore take
// precedence for those fields (Chisel last-connect semantics), so this
// statement must stay ahead of them.
div.io <> io

// Operation type of the incoming uop and its two operands truncated to XLEN bits.
val func = io.in.bits.uop.ctrl.fuOpType
val (src1, src2) = (
io.in.bits.src(0)(XLEN - 1, 0),
io.in.bits.src(1)(XLEN - 1, 0)
)

// Decode the operation type into the three flags the divider needs.
val isW = MDUOpType.isW(func)
val isH = MDUOpType.isH(func)
val isDivSign = MDUOpType.isDivSign(func)
// Operand adjustment: when isW is set only the low 32 bits of the operand are
// used, sign-extended to XLEN if isDivSign is set and zero-extended otherwise.
// When isW is clear the operand is passed through unchanged.
val divInputFunc = (x: UInt) => Mux(
isW,
Mux(isDivSign,
SignExt(x(31, 0), XLEN),
ZeroExt(x(31, 0), XLEN)
),
x
)
// Override the operands supplied by the bulk connection with the adjusted ones.
div.io.in.bits.src(0) := divInputFunc(src1)
div.io.in.bits.src(1) := divInputFunc(src2)
// Forward the decoded flags to the divider's control bundle.
// NOTE(review): isHi presumably selects the remainder/high result — confirm in SRT16Divider.
div.ctrl.isHi := isH
div.ctrl.isW := isW
div.ctrl.sign := isDivSign
}
26 changes: 19 additions & 7 deletions src/main/scala/xiangshan/backend/issue/BypassNetwork.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@ import utils._
/**
 * Payload of one bypass source: a single data word plus one valid flag per way.
 *
 * @param numWays  number of valid flags carried alongside the data
 * @param dataBits width of the data word in bits
 */
class BypassInfo(numWays: Int, dataBits: Int) extends Bundle {
// One flag per way.
// NOTE(review): presumably valid(i) means way i should take `data` — confirm in BypassNetwork.
val valid = Vec(numWays, Bool())
// The bypassed value, shared by all ways.
val data = UInt(dataBits.W)

}

/**
 * IO bundle of a bypass network.
 *
 * @param numWays   number of source/target data ways
 * @param numBypass number of bypass sources
 * @param dataBits  width of each data word in bits
 */
class BypassNetworkIO(numWays: Int, numBypass: Int, dataBits: Int) extends Bundle {
// Single-bit control input.
// NOTE(review): appears to act as a stall/hold request for registered implementations — confirm.
val hold = Input(Bool())
// Per-way input data.
val source = Vec(numWays, Input(UInt(dataBits.W)))
// Per-way output data.
val target = Vec(numWays, Output(UInt(dataBits.W)))
// Bypass sources, each carrying per-way valid flags and one data word.
val bypass = Vec(numBypass, Input(new BypassInfo(numWays, dataBits)))

}

class BypassNetwork(numWays: Int, numBypass: Int, dataBits: Int)(implicit p: Parameters)
Expand All @@ -60,13 +58,17 @@ class BypassNetwork(numWays: Int, numBypass: Int, dataBits: Int)(implicit p: Par
class BypassNetworkRight(numWays: Int, numBypass: Int, dataBits: Int)(implicit p: Parameters)
extends BypassNetwork(numWays, numBypass, dataBits) {

val last_cycle_hold = RegInit(false.B)
last_cycle_hold := io.hold

val target_reg = Reg(Vec(numWays, UInt(dataBits.W)))
val bypass_reg = Reg(Vec(numBypass, new BypassInfo(numWays, dataBits)))

when (io.hold) {
target_reg := io.target
// When last cycle holds the data, no need to update it.
when (io.hold && !last_cycle_hold) {
bypass_reg.map(_.valid.map(_ := false.B))
}.otherwise {
target_reg := io.target
}.elsewhen(!io.hold) {
target_reg := io.source
for ((by_reg, by_io) <- bypass_reg.zip(io.bypass)) {
by_reg.data := by_io.data
Expand Down Expand Up @@ -98,7 +100,17 @@ class BypassNetworkLeft(numWays: Int, numBypass: Int, dataBits: Int)(implicit p:
}

object BypassNetwork {
def apply(numWays: Int, numBypass: Int, dataBits: Int, optFirstStage: Boolean)(implicit p: Parameters) = {
Module(new BypassNetworkLeft(numWays, numBypass, dataBits))
def apply(
numWays: Int,
numBypass: Int,
dataBits: Int,
optFirstStage: Boolean
)(implicit p: Parameters): BypassNetwork = {
if (optFirstStage) {
Module(new BypassNetworkLeft(numWays, numBypass, dataBits))
}
else {
Module(new BypassNetworkRight(numWays, numBypass, dataBits))
}
}
}
25 changes: 13 additions & 12 deletions src/main/scala/xiangshan/backend/issue/ReservationStation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ case class RSParams
var hasFeedback: Boolean = false,
var fixedLatency: Int = -1,
var checkWaitBit: Boolean = false,
var optBuf: Boolean = false,
// special cases
var isJump: Boolean = false,
var isAlu: Boolean = false,
Expand All @@ -63,6 +62,7 @@ case class RSParams
def needBalance: Boolean = exuCfg.get.needLoadBalance && exuCfg.get != LdExeUnitCfg
def numSelect: Int = numDeq + numEnq + (if (oldestFirst._1) 1 else 0)
def dropOnRedirect: Boolean = !(isLoad || isStore || isStoreData)
def optDeqFirstStage: Boolean = !exuCfg.get.readFpRf

override def toString: String = {
s"type ${exuCfg.get.name}, size $numEntries, enq $numEnq, deq $numDeq, numSrc $numSrc, fast $numFastWakeup, wakeup $numWakeup"
Expand Down Expand Up @@ -340,6 +340,8 @@ class ReservationStation(params: RSParams)(implicit p: Parameters) extends XSMod
// select the issue instructions
// Option 1: normal selection (do not care about the age)
select.io.request := statusArray.io.canIssue

select.io.balance
// Option 2: select the oldest
val enqVec = VecInit(s0_doEnqueue.zip(s0_allocatePtrOH).map{ case (d, b) => RegNext(Mux(d, b, 0.U)) })
val s1_oldestSel = AgeDetector(params.numEntries, enqVec, statusArray.io.flushed, statusArray.io.canIssue)
Expand Down Expand Up @@ -485,14 +487,6 @@ class ReservationStation(params: RSParams)(implicit p: Parameters) extends XSMod
oldestSelection.io.oldest := s1_in_oldestPtrOH
// By default, we use the default victim index set in parameters.
oldestSelection.io.canOverride := (0 until params.numDeq).map(_ == params.oldestFirst._3).map(_.B)
// When deq width is two, we have a balance bit to indicate selection priorities.
// For better performance, we decide the victim according to selection priorities.
if (params.needBalance && params.oldestFirst._2 && params.numDeq == 2) {
// When balance2 bit is set, selection prefers the second selection port.
// Thus, the first is the victim if balance2 bit is set.
oldestSelection.io.canOverride(0) := select.io.grantBalance
oldestSelection.io.canOverride(1) := !select.io.grantBalance
}
s1_issue_oldest := oldestSelection.io.isOverrided
}

Expand Down Expand Up @@ -754,8 +748,8 @@ class ReservationStation(params: RSParams)(implicit p: Parameters) extends XSMod
}
}

val bypassNetwork = BypassNetwork(params.numSrc, params.numFastWakeup, params.dataBits, params.optBuf)
bypassNetwork.io.hold := !s2_deq(i).ready
val bypassNetwork = BypassNetwork(params.numSrc, params.numFastWakeup, params.dataBits, params.optDeqFirstStage)
bypassNetwork.io.hold := !s2_deq(i).ready || !s1_out(i).valid
bypassNetwork.io.source := s1_out(i).bits.src.take(params.numSrc)
bypassNetwork.io.bypass.zip(wakeupBypassMask.zip(io.fastDatas)).foreach { case (by, (m, d)) =>
by.valid := m
Expand Down Expand Up @@ -892,11 +886,18 @@ class ReservationStation(params: RSParams)(implicit p: Parameters) extends XSMod
}
}

if (select.io.balance.isDefined) {
require(params.numDeq == 2)
val balance = select.io.balance.get
balance.tick := (balance.out && !s1_out(0).fire && s1_out(1).fire) ||
(!balance.out && s1_out(0).fire && !s1_out(1).fire && !io.fromDispatch(0).fire)
}

// logs
for ((dispatch, i) <- io.fromDispatch.zipWithIndex) {
XSDebug(dispatch.valid && !dispatch.ready, p"enq blocked, robIdx ${dispatch.bits.robIdx}\n")
XSDebug(dispatch.fire, p"enq fire, robIdx ${dispatch.bits.robIdx}, srcState ${Binary(dispatch.bits.srcState.asUInt)}\n")
XSPerfAccumulate(s"allcoate_fire_$i", dispatch.fire)
XSPerfAccumulate(s"allocate_fire_$i", dispatch.fire)
XSPerfAccumulate(s"allocate_valid_$i", dispatch.valid)
XSPerfAccumulate(s"srcState_ready_$i", PopCount(dispatch.bits.srcState.map(_ === SrcState.rdy)))
if (params.checkWaitBit) {
Expand Down
Loading

0 comments on commit 350b5a9

Please sign in to comment.