Bootstrap

[python]国密SM3算法的实现

SM3是国密系列算法中的哈希算法,对于任意长度的输入,它都输出固定的256bit数据,可用于通信过程中的数字认证。
哈希算法实质上是一个单向函数。它要求已知a非常容易求出b,但知道b却无法求出a。在学习单向函数时,有一个极端的概念叫做硬核谓词,它表示为f(x)=0|1,即任意长度的输入x经过函数f后都能得到一个确定的输出——0或者1,但由于输出仅有1个比特,所以无法逆向得到输入数据x。事实上,仅依靠输出我们甚至无法得知x的长度。现实中使用的单向函数输出通常不会是1bit,但原理与硬核谓词是一样的。

SM3算法主要包括数据填充数据扩展迭代压缩三个过程。分别对应于代码中的sm3_fill函数sm3_msg_extend函数sm3_compress函数和sm3_update函数
当然也有一些初始设定:

常量相关设定

在这里插入图片描述

相关函数

在这里插入图片描述

填充规则

在这里插入图片描述

消息扩展规则

在这里插入图片描述

压缩规则

在这里插入图片描述

迭代规则

在这里插入图片描述
最后一轮压缩得到的V就是最终结果。

实现语言:python
版本:3.7

#-------以下函数也可用于其它算法中---------
def rotation_left(x, num):
    # 循环左移
    num %= 32
    left = (x << num) % (2 ** 32)
    right = (x >> (32 - num)) % (2 ** 32)
    result = left ^ right
    return result

def Int2Bin(x, k):
    x = str(bin(x)[2:])
    result = "0" * (k - len(x)) + x
    return result

#-------以上函数也可用于其余算法中-----------

class SM3:

    def __init__(self):
        # 常量初始化
        self.IV = [0x7380166F, 0x4914B2B9, 0x172442D7, 0xDA8A0600, 0xA96F30BC, 0x163138AA, 0xE38DEE4D, 0xB0FB0E4E]
        self.T = [0x79cc4519, 0x7a879d8a]
        self.maxu32 = 2 ** 32
        self.w1 = [0] * 68
        self.w2 = [0] * 64

    def ff(self, x, y, z, j):
        # 布尔函数FF
        result = 0
        if j < 16:
            result = x ^ y ^ z
        elif j >= 16:
            result = (x & y) | (x & z) | (y & z)
        return result

    def gg(self, x, y, z, j):
        # 布尔函数GG
        result = 0
        if j < 16:
            result = x ^ y ^ z
        elif j >= 16:
            result = (x & y) | (~x & z)
        return result

    def p(self, x, mode):
        result = 0
        # 置换函数P
        # 输入参数X的长度为32bit(=1个字)
        # 输入参数mode共两种取值:0和1
        if mode == 0:
            result = x ^ rotation_left(x, 9) ^ rotation_left(x, 17)
        elif mode == 1:
            result = x ^ rotation_left(x, 15) ^ rotation_left(x, 23)
        return result

    def sm3_fill(self, msg):
        # 填充消息,使其长度为512bit的整数倍
        # 输入参数msg为bytearray类型
        # 中间参数msg_new_bin为二进制string类型
        # 输出参数msg_new_bytes为bytearray类型
        length = len(msg)   # msg的长度(单位:byte)
        l = length * 8      # msg的长度(单位:bit)

        num = length // 64
        remain_byte = length % 64
        msg_remain_bin = ""
        msg_new_bytes = bytearray((num + 1) * 64)  ##填充后的消息长度,单位:byte

        # 将原数据存储至msg_new_bytes中
        for i in range(length):
            msg_new_bytes[i] = msg[i]

        # remain部分以二进制字符串形式存储
        remain_bit = remain_byte * 8     #单位:bit
        for i in range(remain_byte):
            msg_remain_bin += "{:08b}".format(msg[num * 64 + i])

        k = (448 - l - 1) % 512
        while k < 0:
            # k为满足 l + k + 1 = 448 % 512 的最小非负整数
            k += 512

        msg_remain_bin += "1" + "0" * k + Int2Bin(l, 64)

        for i in range(0, 64 - remain_byte):
            str = msg_remain_bin[i * 8 + remain_bit: (i + 1) * 8 + remain_bit]
            temp = length + i
            msg_new_bytes[temp] = int(str, 2) #将2进制字符串按byte为组转换为整数
        return msg_new_bytes

    def sm3_msg_extend(self, msg):
        # 扩展函数: 将512bit的数据msg扩展为132个字(w1共68个字,w2共64个字)
        # 输入参数msg为bytearray类型,长度为512bit=64byte
        for i in range(0, 16):
            self.w1[i] = int.from_bytes(msg[i * 4:(i + 1) * 4], byteorder="big")

        for i in range(16, 68):
            self.w1[i] = self.p(self.w1[i-16] ^ self.w1[i-9] ^ rotation_left(self.w1[i-3], 15), 1) ^ rotation_left(self.w1[i-13], 7) ^ self.w1[i-6]

        for i in range(64):
            self.w2[i] = self.w1[i] ^ self.w1[i+4]

        # 测试扩展数据w1和w2
        # print("w1:")
        # for i in range(0, len(self.w1), 8):
        #     print(hex(self.w1[i]))
        # print("w2:")
        # for i in range(0, len(self.w2), 8):
        #     print(hex(self.w2[i]))

    def sm3_compress(self,msg):
        # 压缩函数
        # 输入参数v为初始化参数,类型为bytes/bytearray,大小为256bit
        # 输入参数msg为512bit的待压缩数据

        self.sm3_msg_extend(msg)
        ss1 = 0

        A = self.IV[0]
        B = self.IV[1]
        C = self.IV[2]
        D = self.IV[3]
        E = self.IV[4]
        F = self.IV[5]
        G = self.IV[6]
        H = self.IV[7]

        for j in range(64):
            if j < 16:
                ss1 = rotation_left((rotation_left(A, 12) + E + rotation_left(self.T[0], j)) % self.maxu32, 7)
            elif j >= 16:
                ss1 = rotation_left((rotation_left(A, 12) + E + rotation_left(self.T[1], j)) % self.maxu32, 7)
            ss2 = ss1 ^ rotation_left(A, 12)
            tt1 = (self.ff(A, B, C, j) + D + ss2 + self.w2[j]) % self.maxu32
            tt2 = (self.gg(E, F, G, j) + H + ss1 + self.w1[j]) % self.maxu32
            D = C
            C = rotation_left(B, 9)
            B = A
            A = tt1
            H = G
            G = rotation_left(F, 19)
            F = E
            E = self.p(tt2, 0)

            # 测试IV的压缩中间值
            # print("j= %d:" % j, hex(A)[2:], hex(B)[2:], hex(C)[2:], hex(D)[2:], hex(E)[2:], hex(F)[2:], hex(G)[2:], hex(H)[2:])

        self.IV[0] ^= A
        self.IV[1] ^= B
        self.IV[2] ^= C
        self.IV[3] ^= D
        self.IV[4] ^= E
        self.IV[5] ^= F
        self.IV[6] ^= G
        self.IV[7] ^= H

    def sm3_update(self, msg):
        # 迭代函数
        # 输入参数msg为bytearray类型
        # msg_new为bytearray类型
        msg_new = self.sm3_fill(msg)   # msg_new经过填充后一定是512的整数倍
        n = len(msg_new) // 64         # n是整数,n>=1

        for i in range(0, n):
            self.sm3_compress(msg_new[i * 64:(i + 1) * 64])

    def sm3_final(self):
        digest_str = ""
        for i in range(len(self.IV)):
            digest_str += hex(self.IV[i])[2:]

        return digest_str.upper()

    def hashFile(self, filename):
        with open(filename,'rb') as fp:
            contents = fp.read()
            self.sm3_update(bytearray(contents))
        return self.sm3_final()

if __name__ == "__main__":

    msg1 = bytearray(b"abc")
    print("msg1:", msg1.hex(), len(msg1))

    test1 = SM3()
    test1.sm3_update(msg1)
    digest1 = test1.sm3_final()
    print("digest1:", digest1)
    
    msg2 = bytearray(b'abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd')
    msg2 = bytes(msg2)
    print("msg2:", msg2.hex(), len(msg2))

    test2 = SM3()
    test2.sm3_update(msg2)
    digest2 = test2.sm3_final()
    print("digest2:", digest2)

    # 求大小为48M的文件的摘要,大约需要7分钟
    # test3 = SM3()
    # file_digest = test3.hashFile("test.exe")
    # print('file_digest', file_digest)
;