对于聚合阶段,节点之间有较高的并行性,为此设计了一个乘累加器。假设输入两个乘数A、B,输出为result,则有公式如下:
$$\mathrm{result} = A \times B + \mathrm{result}$$
乘累加器可有效用于节点聚合阶段,将相邻节点的特征聚合。若特征过多,可以使用多个乘累加器,以类似 SIMD 的方式并行执行。这里设计的乘累加器(累加结果)位宽为 3×8 = 24 bit,输入乘数的位宽限制为 8 bit。
chisel实现如下
主函数
package FAM
import chisel3._
import chisel3.util._
import os.read
/** Multiply-accumulate unit: multiplies two signed operands and adds the product
  * to a running sum held in a register.
  *
  * The multiplication is delegated to a [[BoothMultiplierBase4]] instance whose
  * `start` input is driven by `io.accumulate`. Whenever the multiplier reports
  * `done`, its product is added to the accumulator and `io.done` is raised for the
  * following cycle. `io.clear` is synchronous and takes priority over accumulation.
  *
  * @param DATA_WIDTH bit width of each signed operand; the accumulator is
  *                   `3 * DATA_WIDTH` bits wide
  */
class MAC(val DATA_WIDTH: Int = 8) extends Module {
  // Width of the running sum (24 bits for the default 8-bit operands).
  private val accWidth = 3 * DATA_WIDTH

  val io = IO(new Bundle {
    val accumulate = Input(Bool()) // enable: forwarded to the multiplier's start input
    val clear = Input(Bool()) // synchronous clear of the accumulator
    val multiplier1 = Input(SInt(DATA_WIDTH.W))
    val multiplier2 = Input(SInt(DATA_WIDTH.W))
    val result = Output(SInt(accWidth.W))
    val done = Output(Bool())
  })

  // Running sum of all products accumulated since reset or the last clear.
  val regResult = RegInit(0.S(accWidth.W))
  // Registered completion flag that feeds io.done.
  val doneReg = RegInit(false.B)

  val mul0 = Module(new BoothMultiplierBase4(DATA_WIDTH))
  mul0.io.start := io.accumulate
  mul0.io.a := io.multiplier1
  mul0.io.b := io.multiplier2

  // The flag is set only in a cycle where a product is actually accumulated;
  // a simultaneous clear suppresses it.
  doneReg := mul0.io.done && !io.clear

  when(io.clear) {
    regResult := 0.S
  }.elsewhen(mul0.io.done) {
    regResult := regResult + mul0.io.product
  }

  // Outputs come straight from the registers.
  io.result := regResult
  io.done := doneReg
}
/* Entry point: elaborates a default-width MAC and writes its Verilog to ./verilog/FAM. */
object MAC extends App {
  private val emitArgs = Array("--target-dir", "./verilog/FAM")
  private val stage = new chisel3.stage.ChiselStage
  stage.emitVerilog(new MAC(), emitArgs)
}
booth编码乘法器
package FAM
import chisel3._
import chisel3.util._
import javax.xml.transform.OutputKeys
/** Multi-cycle radix-4 (modified) Booth multiplier for signed operands.
  *
  * Control is a four-state machine: `idle -> caculate -> bubble -> done -> idle`.
  * A high `io.start` in `idle` moves to `caculate`, where the `DATA_WIDTH / 2`
  * partial products are captured into registers. `io.a` and `io.b` are read
  * combinationally in that state, so they must still be valid in the cycle after
  * `start` was sampled. `io.done` is high while the machine is in `done`.
  * `io.product` is forced to zero in `idle`; in every other state it is the
  * weighted sum of the partial-product registers.
  *
  * @param DATA_WIDTH bit width of each signed operand; must be even and at least 2
  *                   because the multiplier is recoded two bits at a time
  */
class BoothMultiplierBase4(val DATA_WIDTH: Int = 8) extends Module {
  // With an odd width the last recoding group would miss the sign bit of b, and a
  // zero width would leave nothing to sum, so reject both at elaboration time.
  require(
    DATA_WIDTH >= 2 && DATA_WIDTH % 2 == 0,
    s"DATA_WIDTH must be an even number >= 2, got $DATA_WIDTH"
  )

  val io = IO(new Bundle {
    val a = Input(SInt(DATA_WIDTH.W)) // signed multiplicand
    val b = Input(SInt(DATA_WIDTH.W)) // signed multiplier (Booth-recoded operand)
    val start = Input(Bool())
    val done = Output(Bool())
    val product = Output(SInt((2 * DATA_WIDTH).W)) // signed product
  })

  // State register tracking the progress of the multiplication (0 encodes idle).
  val state = RegInit(0.U(2.W))
  val booth_bits = Wire(Vec(DATA_WIDTH / 2, UInt(3.W)))
  val partial_products = RegInit(VecInit(Seq.fill(DATA_WIDTH / 2)(0.S((2 * DATA_WIDTH).W))))

  // b with an implicit 0 appended below the LSB, as radix-4 recoding requires.
  val b_extended = io.b << 1
  // Expanding subtraction: a plain `-io.a` keeps the operand width, so the most
  // negative value (e.g. -128 for 8 bits) would wrap back onto itself.
  val a_neg = 0.S -& io.a
  val a_pos = io.a

  // State encoding.
  val idle :: caculate :: bubble :: done :: Nil = Enum(4)

  // Group i looks at bits (b[2i+1], b[2i], b[2i-1]) of the multiplier.
  for (i <- 0 until DATA_WIDTH / 2) {
    booth_bits(i) := b_extended(2 * i + 2, 2 * i)
  }

  switch(state) {
    is(idle) {
      when(io.start) {
        state := caculate
      }
    }
    is(caculate) {
      // Select 0, +a, +2a, -2a or -a for every group; the narrower selections are
      // sign-extended when stored into the 2*DATA_WIDTH-bit registers.
      for (i <- 0 until DATA_WIDTH / 2) {
        val code = booth_bits(i)
        partial_products(i) := MuxCase(0.S, Seq(
          (code === 0.U || code === 7.U) -> 0.S,
          (code === 1.U || code === 2.U) -> a_pos,
          (code === 3.U) -> (a_pos << 1),
          (code === 4.U) -> (a_neg << 1),
          (code === 5.U || code === 6.U) -> a_neg
        ))
      }
      state := bubble
    }
    is(bubble) {
      state := done
    }
    is(done) {
      state := idle
    }
  }

  io.done := (state === done)

  // Combinational sum of the partial products, each weighted by 4^i.
  val weightedSum = partial_products.zipWithIndex.map {
    case (pp, i) => pp << (2 * i)
  }.reduce(_ + _)
  // The shifted terms are wider than the product port; keep the low 2*DATA_WIDTH bits.
  val finalProduct = weightedSum(2 * DATA_WIDTH - 1, 0).asSInt

  // Outside a multiplication the product output is held at zero.
  io.product := Mux(state === idle, 0.S((2 * DATA_WIDTH).W), finalProduct)
}
/* Entry point: elaborates a default-width BoothMultiplierBase4 and writes its Verilog to ./verilog/FAM. */
object BoothMultiplierBase4 extends App {
  private val emitArgs = Array("--target-dir", "./verilog/FAM")
  private val stage = new chisel3.stage.ChiselStage
  stage.emitVerilog(new BoothMultiplierBase4(), emitArgs)
}
测试代码
import scala.util.Random
import org.scalatest._
import chiseltest._
import chisel3._
import FAM.SqrtInv
// Randomised test bench for the SqrtInv module: drives two degree inputs and
// compares the module output against round(256 / sqrt(a * b)).
class Power_1_2Test extends FreeSpec with ChiselScalatestTester {
  "Power -1/2 should pass" in {
    test(new SqrtInv)
      .withAnnotations(Seq(WriteVcdAnnotation)) { c => // also dump a .vcd waveform
        println("Start Testing")
        for (i <- 0 until 10) {
          // Draw from 1..255. A zero operand would make the reference value
          // 256 / sqrt(0) infinite, which math.round turns into Long.MaxValue and
          // which no hardware output can match, so the test would fail spuriously.
          // NOTE(review): SqrtInv's behaviour for a zero degree is not exercised here.
          val a = 1 + Random.nextInt(255)
          val b = 1 + Random.nextInt(255)
          c.io.start_point_degree.poke(a.U) // drive a as an unsigned value
          c.io.end_point_degree.poke(b.U) // drive b as an unsigned value
          c.io.start.poke(true.B)
          c.clock.step(2)
          // Advance one cycle at a time until the module reports completion.
          while (!c.io.done.peekBoolean()) {
            c.clock.step(1)
          }
          // Reference value computed in software.
          val expectedResult = math.round(256 / math.sqrt(a * b))
          // peek() reads the current value of the port without advancing the clock;
          // litValue extracts it as a Scala BigInt.
          val actualResult = c.io.out.peek().litValue.toLong
          println(s"Iteration: $i, A: $a, B: $b, Expected Result: $expectedResult, Actual Result: $actualResult")
          assert(actualResult === expectedResult, s"Product is incorrect at iteration $i!\n Start_point_degree is $a, end point degree is $b.\n Expected: $expectedResult, Actual: $actualResult")
        }
      }
  }
}
上述测试代码可生成vcd波形,可使用GTKwave等方式进行打开
输出结果
在vscode中使用metals进行测试,调试控制台端口处显示
MACTest
Start Testing
Iteration: 0, A: 74, B: -126, Expected Result: -9324, Actual Result: -9324
Iteration: 1, A: -52, B: -70, Expected Result: -5684, Actual Result: -5684
Iteration: 2, A: 67, B: -51, Expected Result: -9101, Actual Result: -9101
Iteration: 3, A: -48, B: -40, Expected Result: -7181, Actual Result: -7181
Iteration: 4, A: 60, B: 71, Expected Result: -2921, Actual Result: -2921
Iteration: 5, A: -18, B: -86, Expected Result: -1373, Actual Result: -1373
Iteration: 6, A: 53, B: 79, Expected Result: 2814, Actual Result: 2814
Iteration: 7, A: 51, B: 75, Expected Result: 6639, Actual Result: 6639
Iteration: 8, A: 78, B: 35, Expected Result: 9369, Actual Result: 9369
Iteration: 9, A: 95, B: 102, Expected Result: 19059, Actual Result: 19059
- MACTest should pass
Execution took 1.81s
1 tests, 1 passed
All tests in MACTest passed