Bootstrap

chisel入门初步2_1——乘累加器设计

对于聚合阶段,节点之间有较高的并行性,为此设计了一个乘累加器。假设输入两个乘数A、B,输出为result,则有公式如下:
$$\mathit{result} \leftarrow A \times B + \mathit{result}$$
即每完成一次乘法,就把新的乘积累加到已有的结果上。
乘累加器可有效用于节点聚合阶段,将相邻节点的特征聚合。若特征过多,可以使用多个乘累加器,以类似 SIMD 的方式并行执行。这里设计的乘累加器的累加结果位宽为 3×8 = 24 bit,输入的两个乘数均为 8 bit 有符号数(DATA_WIDTH = 8)。

Chisel 实现如下。
顶层模块:乘累加器 MAC

package FAM
import chisel3._  
import chisel3.util._  
import os.read

/** Multiply-accumulate unit: every time the internal Booth multiplier reports
  * `done`, its product (multiplier1 * multiplier2) is added to a signed
  * 3*DATA_WIDTH-bit accumulator.
  *
  *  - `accumulate` is forwarded to the multiplier's `start` input.
  *  - `clear` synchronously zeroes the accumulator and the `done` flag; it has
  *    priority over a multiplication finishing in the same cycle.
  *  - `done` is registered: it is high in the cycle after the multiplier
  *    reported `done` while `clear` was low.
  *  - The accumulation uses Chisel's non-expanding `+`, so the 3*DATA_WIDTH-bit
  *    result wraps on overflow.
  *
  * @param DATA_WIDTH width of each signed multiplicand
  */
class MAC(val DATA_WIDTH: Int = 8) extends Module {
    val io = IO(new Bundle {
        val accumulate = Input(Bool())      // enable: drives the multiplier's start
        val clear = Input(Bool())           // synchronous clear
        val multiplier1 = Input(SInt(DATA_WIDTH.W))
        val multiplier2 = Input(SInt(DATA_WIDTH.W))
        val result = Output(SInt((3*DATA_WIDTH).W))
        val done = Output(Bool())
    })

    // Running sum and the registered "an accumulation just happened" flag.
    val regResult = RegInit(0.S((3*DATA_WIDTH).W))
    val doneReg = RegInit(false.B)

    // Single Booth multiplier feeding the accumulator.
    val mul0 = Module(new BoothMultiplierBase4(DATA_WIDTH))
    mul0.io.a     := io.multiplier1
    mul0.io.b     := io.multiplier2
    mul0.io.start := io.accumulate

    // Clear wins over a finished multiplication; otherwise the sum is held.
    when(io.clear) {
        regResult := 0.S
    }.elsewhen(mul0.io.done) {
        regResult := regResult + mul0.io.product
    }

    // Raised only when a product was actually accumulated this cycle.
    doneReg := mul0.io.done && !io.clear

    io.result := regResult
    io.done   := doneReg
}
  
/** Entry point: elaborates [[MAC]] with its default parameters and writes the
  * generated Verilog into ./verilog/FAM.
  */
object MAC {
    def main(args: Array[String]): Unit = {
        val stageArgs = Array("--target-dir", "./verilog/FAM")
        (new chisel3.stage.ChiselStage).emitVerilog(new MAC(), stageArgs)
    }
}

基 4 Booth 编码乘法器(BoothMultiplierBase4)

package FAM
import chisel3._  
import chisel3.util._  
import javax.xml.transform.OutputKeys

  
/** Radix-4 (modified) Booth multiplier for two signed DATA_WIDTH-bit operands.
  *
  * Protocol (FSM: idle -> caculate -> bubble -> done -> idle):
  *  - `start` is only observed in the idle state.
  *  - `a` and `b` are sampled in the caculate state, i.e. the cycle after
  *    `start` was seen, so they must still be valid in that cycle.
  *  - `done` is high for the single cycle the FSM spends in the done state.
  *  - `product` is 0 while idle; otherwise it is the weighted sum of the
  *    registered partial products, which hold the result for the current
  *    operands from the bubble state onward.
  *
  * @param DATA_WIDTH operand width in bits (must be > 0); `product` is
  *                   2*DATA_WIDTH bits. Odd widths are handled by
  *                   sign-extending `b` to an even number of bits.
  */
class BoothMultiplierBase4(val DATA_WIDTH: Int = 8) extends Module {
    require(DATA_WIDTH > 0, "DATA_WIDTH must be positive")

    val io = IO(new Bundle {
        val a = Input(SInt(DATA_WIDTH.W))  // Signed input a
        val b = Input(SInt(DATA_WIDTH.W))  // Signed input b
        val start = Input(Bool())
        val done = Output(Bool())
        val product = Output(SInt((2 * DATA_WIDTH).W)) // Signed output product
    })

    // Number of radix-4 Booth digits; rounding up keeps the sign bit of b
    // covered when DATA_WIDTH is odd (for even widths this is DATA_WIDTH / 2).
    val numDigits = (DATA_WIDTH + 1) / 2

    // FSM states.
    val idle :: caculate :: bubble :: done :: Nil = Enum(4)
    val state = RegInit(idle) // tracks the progress of the multiplication

    val booth_bits = Wire(Vec(numDigits, UInt(3.W)))
    val partial_products = RegInit(VecInit(Seq.fill(numDigits)(0.S((2 * DATA_WIDTH).W))))

    // b sign-extended to 2*numDigits bits with the implicit b(-1) = 0 appended,
    // so digit i reads the bit triple (b(2i+1), b(2i), b(2i-1)).
    val b_extended = Cat(io.b.pad(2 * numDigits).asUInt, 0.U(1.W))
    val a_pos = io.a
    // Expanding subtraction makes -(-2^(DATA_WIDTH-1)) representable. A
    // same-width negation (which unary `-` may be, depending on the Chisel
    // version) wraps it back to -2^(DATA_WIDTH-1) and corrupts the -A/-2A digits.
    val a_neg = 0.S -& io.a

    for (i <- 0 until numDigits) {
        booth_bits(i) := b_extended(2 * i + 2, 2 * i)
    }

    switch(state) {
        is(idle) {
            when(io.start) {
                state := caculate
            }
        }

        is(caculate) {
            // Booth digit -> partial product in {0, +A, +2A, -2A, -A}; the
            // narrower candidates are sign-extended on assignment to the register.
            for (i <- 0 until numDigits) {
                val code = booth_bits(i)
                partial_products(i) := MuxCase(0.S, Seq(
                    (code === 0.U || code === 7.U) -> 0.S,
                    (code === 1.U || code === 2.U) -> a_pos,
                    (code === 3.U)                 -> (a_pos << 1),
                    (code === 4.U)                 -> (a_neg << 1),
                    (code === 5.U || code === 6.U) -> a_neg
                ))
            }
            state := bubble
        }

        is(bubble) {
            state := done
        }

        is(done) {
            state := idle
        }
    }

    io.done := (state === done)

    // Digit i carries weight 4^i. The wrapping sum is wider than the product,
    // and the exact signed product always fits in 2*DATA_WIDTH bits, so keeping
    // the low 2*DATA_WIDTH bits is exact.
    val finalProduct = partial_products.zipWithIndex
        .map { case (pp, i) => pp << (2 * i) }
        .reduce(_ + _)
    val productBits = finalProduct(2 * DATA_WIDTH - 1, 0).asSInt

    // Drive a defined value even while no multiplication is in progress.
    io.product := Mux(state === idle, 0.S((2 * DATA_WIDTH).W), productBits)
}
  
/** Entry point: elaborates [[BoothMultiplierBase4]] with its default width and
  * writes the generated Verilog into ./verilog/FAM.
  */
object BoothMultiplierBase4 {
    def main(args: Array[String]): Unit = {
        val stageArgs = Array("--target-dir", "./verilog/FAM")
        (new chisel3.stage.ChiselStage).emitVerilog(new BoothMultiplierBase4(), stageArgs)
    }
}

测试代码(注意:下面给出的测试代码实际针对的是 SqrtInv 模块,而不是本文的 MAC 模块;后文展示的输出结果来自另一个名为 MACTest 的测试)

import scala.util.Random
import org.scalatest._
import chiseltest._
import chisel3._
import FAM.SqrtInv

/** Randomized test for SqrtInv (note: this exercises SqrtInv, not the MAC).
  *
  * Each iteration drives two random degrees, waits for `done`, and compares
  * `out` against the software reference round(256 / sqrt(a * b)).
  */
class Power_1_2Test extends FreeSpec with ChiselScalatestTester {
    "Power -1/2 should pass" in {

        test(new SqrtInv)
        .withAnnotations(Seq(WriteVcdAnnotation))  // generate the .vcd waveform file as output
        { c =>
            println("Start Testing")
            for (i <- 0 until 10) {
                // Draw from 1..255: a zero degree would make the reference
                // 256 / sqrt(0) infinite (math.round then yields Long.MaxValue),
                // so such a sample could never match the hardware output.
                val a = 1 + Random.nextInt(255)
                val b = 1 + Random.nextInt(255)

                c.io.start_point_degree.poke(a.U)   // drive a as an unsigned value
                c.io.end_point_degree.poke(b.U)     // drive b as an unsigned value
                c.io.start.poke(true.B)
                c.clock.step(2)

                // Wait until the module signals completion.
                while (!c.io.done.peekBoolean()) {
                    c.clock.step(1)
                }

                val expectedResult = math.round(256 / math.sqrt(a * b))   // software reference
                // peek() reads the port without advancing the clock;
                // litValue extracts the value as a BigInt.
                val actualResult = c.io.out.peek().litValue.toLong

                println(s"Iteration: $i, A: $a, B: $b, Expected Result: $expectedResult, Actual Result: $actualResult")
                assert(actualResult === expectedResult, s"Product is incorrect at iteration $i!\n Start_point_degree is $a, end point degree is $b.\n Expected: $expectedResult, Actual: $actualResult")
            }
        }
    }
}

上述测试代码通过 WriteVcdAnnotation 生成 .vcd 波形文件,可使用 GTKWave 等工具打开。
输出结果
在 VS Code 中使用 Metals 运行测试,调试控制台输出如下:

MACTest
Start Testing
Iteration: 0, A: 74, B: -126, Expected Result: -9324, Actual Result: -9324
Iteration: 1, A: -52, B: -70, Expected Result: -5684, Actual Result: -5684
Iteration: 2, A: 67, B: -51, Expected Result: -9101, Actual Result: -9101
Iteration: 3, A: -48, B: -40, Expected Result: -7181, Actual Result: -7181
Iteration: 4, A: 60, B: 71, Expected Result: -2921, Actual Result: -2921
Iteration: 5, A: -18, B: -86, Expected Result: -1373, Actual Result: -1373
Iteration: 6, A: 53, B: 79, Expected Result: 2814, Actual Result: 2814
Iteration: 7, A: 51, B: 75, Expected Result: 6639, Actual Result: 6639
Iteration: 8, A: 78, B: 35, Expected Result: 9369, Actual Result: 9369
Iteration: 9, A: 95, B: 102, Expected Result: 19059, Actual Result: 19059
- MACTest should pass
Execution took 1.81s
1 tests, 1 passed
All tests in MACTest passed
;