一、SM2 国密算法介绍:
"""
SM2 国密非对称加密算法,属于椭圆曲线密码体制(ECC)
Author:John
基于椭圆曲线的离散对数难题,目前 SM2 256 bit 加密算法是相当安全的,相当于 RSA 2048 bit 及以上的安全性
有公钥、私钥之分,公钥给别人,可以在一定范围内公开,私钥留给自己,必须保密。由私钥可以计算公钥;由公钥计算私钥,是相当困难的,现阶段是不可能
加密过程:
设需要发送的消息为比特串 M ,klen 为 M 的比特长度。
为了对明文 M 进行加密,作为加密者的用户 A 应实现以下运算步骤:
A1:用随机数发生器产生随机数k∈[1,n-1];
A2:计算椭圆曲线点 C1=[k]G=(x1,y1),([k]G 表示 k*G )将C1的数据类型转换为比特串;
A3:计算椭圆曲线点 S=[h]PB,若S是无穷远点,则报错并退出;
A4:计算椭圆曲线点 [k]PB=(x2,y2),将坐标 x2、y2 的数据类型转换为比特串;
A5:计算t=KDF(x2 ∥ y2, klen),若 t 为全0比特串,则返回 A1;
A6:计算C2 = M ⊕ t;
A7:计算C3 = Hash(x2 ∥ M ∥ y2);
A8:输出密文C = C1 ∥ C2 ∥ C3。
解密过程:
设klen为密文中C2的比特长度。
为了对密文C=C1 ∥ C2 ∥ C3 进行解密,作为解密者的用户 B 应实现以下运算步骤:
B1:从C中取出比特串C1,将C1的数据类型转换为椭圆曲线上的点,验证C1是否满足椭圆曲线方程,若不满足则报错并退出;
B2:计算椭圆曲线点 S=[h]C1,若S是无穷远点,则报错并退出;
B3:计算[dB]C1=(x2,y2),将坐标x2、y2的数据类型转换为比特串;
B4:计算t=KDF(x2 ∥ y2, klen),若t为全0比特串,则报错并退出;
B5:从C中取出比特串C2,计算M′ = C2 ⊕ t;
B6:计算u = Hash(x2 ∥ M′ ∥ y2),从C中取出比特串C3,若u != C3,则报错并退出;
B7:输出明文M′。
原理:
用户 A 持有公钥PB=[dB]G(仅有PB值),用户 B 持有私钥 dB
加密:C1=k*G C2=M⊕(k*PB) 解密:M′=C2 ⊕ (dB*C1) # 这里只叙述基本原理,便于理解
证明:dB*C1=dB*k*G=k*(dB*G)=k*PB 因此,M′=C2 ⊕ (dB*C1)=M⊕(k*PB)⊕(k*PB)=M 得证
注:此实现算法所研究的椭圆曲线是基于域 Fp 上的椭圆曲线
安全参数设置:
随机数 k 和私钥 dB 最好大点,2**50 以上比较安全
successfully!!! 和官方例子加密结果测试一样
"""
二、算法
1、参数
class SM2:
a=0x787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E498
b=0x63E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A
v=256
p= 0x8542D69E4C044F18E8B92435BF6FF7DE457283915C45517D722EDB8B08F1DFC3
Gx=0x421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D
Gy=0x0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A2
n=0x8542D69E4C044F18E8B92435BF6FF7DD297720630485628D5AE74EE7C32E79B7
group=2048
2、加密过程
def encrypt_base(self,m,publickey):
"""
加密基函数
:param m: int 型明文
:return: 密文 16 进制字符串
"""
binm=bin(m)[2:]
binm='0'*(-len(binm)%8)+binm
mlen = len(binm)
x,y=publickey
while(True):
k=random.randint(1<<50,1<<100)
C1=self.oval_multiply(k,(self.Gx,self.Gy))
binC1=self.point_to_bit(C1)
Cb=self.oval_multiply(k,(x,y))
binCb=self.point_to_bit(Cb,False)
t=self.KDF(binCb[0],mlen)
if t.count('0')!=mlen:
break
binC2=self.bin_xor(binm,t)
C3=self.hash(binCb[1]+binm+binCb[2])
hexC1=hex(int(binC1,2))[2:].zfill(130)
hexC2=hex(int(binC2,2))[2:].zfill(mlen//4)
return hexC1 + hexC2 + C3
def encrypt(self,plain,publickey):
"""
加密主函数
:param plain: 明文字符串
:param publickey: 公钥 {x,y}
:return: 16 进制密文
"""
mstrb=ba.b2a_hex(plain.encode())
mstr16=str(mstrb,'utf8')
grouplen=self.group//4
mlist=[]
i=-1
for i in range(len(mstr16)//grouplen):
mlist.append(int(mstr16[grouplen*i:grouplen*(i+1)],16))
if math.ceil(len(mstr16)/grouplen)!=len(mstr16)//grouplen:
mlist.append(int(mstr16[grouplen*(i+1):],16))
cipher=""
for m in mlist:
cipher+=self.encrypt_base(m,publickey)
return cipher
3、解密过程
def decrypt_base(self,cipher,privatekey):
"""
解密基函数
:param cipher: 16 进制串
:param privatekey: 私钥 dB
:return: 16 进制串
"""
dB=privatekey
clen=len(cipher)
hexC1=cipher[:130]
hexC2=cipher[130:clen-self.v//4]
hexC3=cipher[-self.v//4:]
pointC1=self.hex_to_point(hexC1)
if not self.discover_cipher_true_or_false(pointC1):
exit("密文错误,请重新更换正确的密文")
dB_C1=self.oval_multiply(dB,pointC1)
binC1=self.point_to_bit(dB_C1,False)
binlenC2=len(hexC2)*4
binC2=bin(int(hexC2,16))[2:].zfill(binlenC2)
t=self.KDF(binC1[0],binlenC2)
binm=self.bin_xor(binC2,t)
u=self.hash(binC1[1]+binm+binC1[2])
if u!=hexC3:
exit("解密明文非原明文,程序出现错误")
hexplain=hex(int(binm,2))[2:].zfill(len(binm)//4)
return hexplain
def decrypt(self,cipher,privatekey):
"""
解密主函数
:param cipher: 密文 16 进制串
:param privatekey: 私钥 dB
:return: 明文字符串
"""
grouplen=(self.group+256+520)//4
clist = []
i = -1
for i in range(len(cipher) // grouplen):
clist.append(cipher[grouplen * i:grouplen * (i + 1)])
if math.ceil(len(cipher)/grouplen)!=len(cipher)//grouplen:
clist.append(cipher[grouplen * (i + 1):])
plain16=""
for c in clist:
plain16+=self.decrypt_base(c,privatekey)
plain=ba.a2b_hex(plain16).decode()
return plain
4、密钥扩展函数
def KDF(self,xy,klen):
"""
秘钥扩展函数
:param xy:经过转换后的秘钥 2 进制串
:param klen:明文二进制长度
:return: 扩展后的密钥 2 进制串
"""
ct=1
K=""
for i in range(klen//self.v):
K+=self.hash(xy+bin(ct)[2:].zfill(32))
ct+=1
flengh=klen/self.v
if math.ceil(flengh)!=int(flengh):
K+=self.hash(xy+bin(ct)[2:].zfill(32))[:(klen-self.v*int(flengh))//4]
binK=bin(int(K,16))[2:].zfill(klen)
return binK
def hash(self,str1):
"""
哈希函数 国密 SM2 原版 hash 函数 是 SM3 256 bit,这里只是测试方便
:return: 16 进制串 256 bit
"""
sha=sha256()
sha.update(str1.encode())
return sha.hexdigest()
5、辗转相除法求逆元
def find_inverse_element(self,n,p):
'''
辗转相除法求 n 的逆元,不需要 n<p,但需要 n p 都为正整数
注:辗转相除法求 s,t(as+bt=(a,b) 这里默认 a,b 为正整数)
如果 a>b 且不大于 2**32,执行循环的次数最多为 logn(以 2 为底),故定义一维数组的 a[32],b[32] 的长度为 32
这条信息在 C 语言等需要定义数组长度的情况下有用
:param n: int 型
:param p: 模数
:return: n mod p 的逆元
'''
a = [0] * 1000
b = [0] * 1000
a[0] = p
b[0] = n
i = 0
while (a[i] % b[i]):
a[i + 1] = b[i]
b[i + 1] = a[i] % b[i]
i += 1
i -= 1
shang_a = 1
shang_b = -(a[i] // b[i])
while (i != -1):
if (i >= 1):
tmp = shang_a
shang_a = shang_b
shang_b = tmp - a[i - 1] // b[i - 1] * shang_b
i -= 1
return shang_b % a[0]
6、椭圆曲线上点的加法、乘法运算
def oval_same_add(self,G):
"""
椭圆曲线上的相同坐标的两个点相加
:param G: 生成元,基点
:return: 相加之后的点
"""
x1,y1=G
tmp1=3*x1*x1+self.a
tmp2=self.find_inverse_element(2*y1,self.p)
k=tmp1*tmp2%self.p
x3=(k*k-x1-x1)%self.p
y3=(k*(x1-x3)-y1)%self.p
return x3,y3
def oval_diff_add(self,G1,G2):
"""
椭圆曲线上的不同坐标的两个点相加
:param G1,G2: 两个点坐标
:return: 相加之后的点
"""
x1,y1=G1
x2,y2=G2
tmp1=y2-y1
tmp2=self.find_inverse_element((x2-x1)%self.p,self.p)
k=tmp1*tmp2%self.p
x3=(k*k-x1-x2)%self.p
y3=(k*(x1-x3)-y1)%self.p
return x3,y3
def oval_diff_add_near(self,point,pointBase):
"""
相邻的两个点相加 ,pointBase 为基点,如:23P 点 + 24P 点
:return: 新的点
"""
return self.oval_diff_add(pointBase,self.oval_same_add(point))
def oval_multiply(self,k,G):
"""
椭圆曲线上的点乘以常数 k
:param k: int 型 k*G 中的 k
:param G: 生成元,基点
:return: 相乘之后的点
"""
if k==2:
return self.oval_same_add(G)
if k==3:
return self.oval_diff_add(G,self.oval_same_add(G))
if k%2==0:
return self.oval_same_add(self.oval_multiply(k//2,G))
if k%2==1:
return self.oval_diff_add_near(self.oval_multiply(k//2,G),G)
7、公、私钥生成函数
def key_produce(self):
"""
生成公钥 私钥函数
:return: 公钥:PB {x,y} 私钥:dB
"""
dB=random.randint(1<<50,1<<100)
PB=self.oval_multiply(dB,(self.Gx,self.Gy))
return PB,dB
三、算法流程
my=SM2()
key=my.key_produce()
PB,dB=key
plain='123456'
cipher=my.encrypt(plain,PB)
print("密文:",cipher)
p=my.decrypt(cipher,dB)
print("明文:",p)
if plain==p:
print("加解密成功")
else:
print("失败")
PS:本算法计算的加密值与 国密官方文档实例 一致