# Bootstrap
#
# Python 的 SM4 算法：sm4_apply 提供对 str 的加解密操作，是对 sm4_alg 的封装。

"""
Author:tanglei
DateTime:2024-10-12
微信:ciss_cedar
欢迎一起学习
面向应用提供字符输入的支持,对sm4_alg 进行二次封装
"""
import base64
import os
from datetime import datetime

from cryptography.hazmat.primitives import hashes

from utils.algorithm.symmetric.sm4_alg import gen_random_bytes
from utils.algorithm.symmetric.sm4_alg import sm4_cbc_encrypt_bytes, sm4_cbc_decrypt_bytes

# Ciphertext text encodings for the str helpers: index 0 ('Hex') is the default,
# any other value is treated as base64.
code_tuple = 'Hex', 'base64'
# Cipher modes for the file helpers: index 0 ('CBC') is the default; CTR is not implemented.
mode_tuple = 'CBC', 'CTR'


# def sm4_ecb_encrypt(key,source,code=code_tuple[0]):
#     key_bytes=bytes.fromhex(key)
#     plain_byts=source.encode()
#     encrypt_bytes = sm4_ecb_encrypt_bytes(key_bytes,plain_byts)
#     if code==code_tuple[0]:
#         enc_source=encrypt_bytes.hex().upper()
#     else:
#         enc_source=base64.b64encode(encrypt_bytes).decode()
#     return enc_source
#
# def sm4_ecb_decrypt(key,enc_source,code=code_tuple[0]):
#     key_bytes=bytes.fromhex(key)
#     if code==code_tuple[0]:
#         cipher_bytes = bytes.fromhex(enc_source)
#     else :
#         cipher_bytes = base64.b64decode(enc_source)
#
#     plain_bytes = sm4_ecb_decrypt_bytes(key_bytes,cipher_bytes)
#     source=plain_bytes.decode()
#
#     return source

def sm4_cbc_encrypt(key, iv, source, code=code_tuple[0]):
    """
    Encrypt a text string with SM4-CBC and return the ciphertext as text.

    :param key: SM4 key as a hex string
    :param iv: SM4 IV as a hex string
    :param source: plaintext str; it is UTF-8 encoded (str.encode default) before encryption
    :param code: 'Hex' (the default) for upper-case hex output; any other value yields base64
    :return: the encoded ciphertext as a str
    """
    ciphertext = sm4_cbc_encrypt_bytes(bytes.fromhex(key), bytes.fromhex(iv), source.encode())
    if code != code_tuple[0]:
        return base64.b64encode(ciphertext).decode()
    return ciphertext.hex().upper()


def sm4_cbc_decrypt(key, iv, enc_source, code=code_tuple[0]):
    """
    Decrypt text produced by ``sm4_cbc_encrypt`` and return the plaintext str.

    :param key: SM4 key as a hex string
    :param iv: SM4 IV as a hex string
    :param enc_source: ciphertext text, hex when ``code`` is 'Hex', otherwise base64
    :param code: 'Hex' (the default) or any other value for base64 input
    :return: the decrypted plaintext, decoded from UTF-8 (bytes.decode default)
    """
    key_bytes, iv_bytes = bytes.fromhex(key), bytes.fromhex(iv)
    is_hex = code == code_tuple[0]
    cipher_bytes = bytes.fromhex(enc_source) if is_hex else base64.b64decode(enc_source)
    plain_bytes = sm4_cbc_decrypt_bytes(key_bytes, iv_bytes, cipher_bytes)
    return plain_bytes.decode()


# def sm4_gcm_encrypt(key,source,aad=None,code=code_tuple[0]):
#     key_bytes=bytes.fromhex(key)
#     plain_bytes=source.encode()
#     if aad is not None:
#        aad_bytes=aad.encode()
#     else:
#        aad_bytes=None
#     encrypt_bytes,tag_bytes=sm4_gcm_encrypt_bytes(key_bytes,plain_bytes,aad_bytes)
#     tag=tag_bytes.hex().upper() #只对加密数据base64编码,是长度变短1.33--1.34倍数的原文
#     if code==code_tuple[0]:
#         enc_source=encrypt_bytes.hex().upper()
#     else:
#         enc_source=base64.b64encode(encrypt_bytes).decode()
#     return enc_source,tag
#
# def sm4_gcm_decrypt(key,enc_source, tag, aad=None,code=code_tuple[0]):
#     key_bytes = bytes.fromhex(key)
#     tag_bytes = bytes.fromhex(tag)
#     if aad is not None:
#         aad_bytes = aad.encode()
#     else:
#         aad_bytes = None
#     if code==code_tuple[0]:
#        cipher_bytes = bytes.fromhex(enc_source)  # bytes类型
#     else:
#        cipher_bytes = base64.b64decode(enc_source)
#     encrypt_bytes = sm4_gcm_decrypt_bytes(key_bytes,cipher_bytes, tag_bytes, aad_bytes)
#     source=encrypt_bytes.decode()
#     return source

def file_encrypt_sm4(key: str, iv: str, source_file: str, mode=mode_tuple[0], filepath=''):
    """
    Encrypt a file with SM4 chunk by chunk and return the SM3 digest of its plaintext.

    The source file is read in 4096-byte chunks. Each chunk is passed on its own
    to ``sm4_cbc_encrypt_bytes`` with the same key and IV, and the results are
    concatenated into the target file. ``file_decrypt_sm4`` depends on this
    layout: it reads 4096 + 16 bytes per encrypted chunk.

    NOTE(review): every chunk reuses the same IV. Assuming ``sm4_cbc_encrypt_bytes``
    is plain deterministic CBC, identical 4096-byte plaintext chunks therefore
    produce identical ciphertext chunks. The layout is kept as-is so that files
    encrypted earlier stay decryptable.

    :param key: SM4 key as a hex string
    :param iv: SM4 IV as a hex string
    :param source_file: full path of the file to encrypt
    :param mode: cipher mode, case-insensitive; only 'CBC' (the default) is implemented
    :param filepath: output directory; when empty, the encrypted file is written
                     next to ``source_file``
    :return: (source_sm3_hash, target_file), where source_sm3_hash is the upper-case
             hex SM3 digest of the original file contents and target_file is the
             full name of the encrypted file
    :raises ValueError: if ``mode`` is unsupported or ``key``/``iv`` is not valid hex
    """
    # Validate up front so that an unsupported mode fails before an empty
    # target file is created (the CTR branch is not implemented).
    if mode.upper() != mode_tuple[0]:
        raise ValueError(f'unsupported mode {mode!r}: only {mode_tuple[0]!r} is implemented')
    key_bytes = bytes.fromhex(key)
    iv_bytes = bytes.fromhex(iv)

    # Suffix is '.<YYYYMMDD><microseconds>_Encrypt'. file_decrypt_sm4 strips it
    # by cutting the base name at its last '.'.
    suffix = '.' + datetime.now().strftime('%Y%m%d%f') + '_Encrypt'
    if not filepath:
        target_file = source_file + suffix
    else:
        target_file = filepath + '/' + os.path.basename(source_file) + suffix

    read_num = 4096
    h_digest = hashes.Hash(hashes.SM3())
    with open(source_file, mode='rb') as source, open(target_file, mode='wb') as target:
        # iter(callable, sentinel) keeps reading until read() returns b'' (EOF).
        for data in iter(lambda: source.read(read_num), b''):
            h_digest.update(data)
            target.write(sm4_cbc_encrypt_bytes(key_bytes, iv_bytes, data))
    return h_digest.finalize().hex().upper(), target_file


def file_decrypt_sm4(key: str, iv: str, encrypt_file: str, mode=mode_tuple[0], filepath=''):
    """
    Decrypt a file produced by ``file_encrypt_sm4`` and return the SM3 digest of the result.

    The output file is named ``Decrypt_<YYYYMMDD><microseconds>.<stem>``. Here
    ``<stem>`` is the encrypted file's base name with its last extension removed,
    i.e. the ``.<timestamp>_Encrypt`` suffix that ``file_encrypt_sm4`` appends.

    :param key: SM4 key as a hex string
    :param iv: SM4 IV as a hex string
    :param encrypt_file: full path of the encrypted file ('\\' or '/' separators)
    :param mode: cipher mode, case-insensitive; only 'CBC' (the default) is implemented
    :param filepath: output directory; when empty, the decrypted file is written
                     in the encrypted file's directory (or the current directory
                     when ``encrypt_file`` has no directory part)
    :return: (decrypt_sm3_hash, decrypt_file), where decrypt_sm3_hash is the
             upper-case hex SM3 digest of the decrypted contents and decrypt_file
             is the full name of the decrypted file
    :raises ValueError: if ``mode`` is unsupported or ``key``/``iv`` is not valid hex
    """
    # Validate up front so that an unsupported mode fails before an empty
    # output file is created (the CTR branch is not implemented).
    if mode.upper() != mode_tuple[0]:
        raise ValueError(f'unsupported mode {mode!r}: only {mode_tuple[0]!r} is implemented')
    key_bytes = bytes.fromhex(key)
    iv_bytes = bytes.fromhex(iv)

    # Normalise Windows separators so the path can be split on '/'.
    encrypt_file = encrypt_file.replace('\\', '/')
    directory, separator, basename = encrypt_file.rpartition('/')
    # Strip only the base name's last extension. A dot inside a directory name
    # must not be mistaken for the extension separator.
    dot = basename.rfind('.')
    filename = basename[:dot] if dot > 0 else basename
    prefix = 'Decrypt_' + datetime.now().strftime('%Y%m%d%f')
    decrypt_name = prefix + '.' + filename
    if filepath:
        decrypt_file = filepath + '/' + decrypt_name
    elif separator:
        decrypt_file = directory + '/' + decrypt_name
    else:
        # Bare file name: write beside it in the current directory, not under '/'.
        decrypt_file = decrypt_name

    # Each encrypted chunk is the 4096-byte plaintext chunk plus 16 bytes of
    # PKCS padding added by sm4_cbc_encrypt_bytes (per the original comment;
    # TODO confirm against sm4_alg). The final chunk may be shorter.
    read_num = 4096 + 16
    h_digest = hashes.Hash(hashes.SM3())
    with open(encrypt_file, mode='rb') as source, open(decrypt_file, mode='wb') as target:
        for data in iter(lambda: source.read(read_num), b''):
            data_bytes = sm4_cbc_decrypt_bytes(key_bytes, iv_bytes, data)
            h_digest.update(data_bytes)
            target.write(data_bytes)
    return h_digest.finalize().hex().upper(), decrypt_file


def gen_sm4_key(length=16):
    """Return ``length`` bytes from ``gen_random_bytes`` as an upper-case hex string (2 chars per byte)."""
    random_bytes = gen_random_bytes(length)
    return random_bytes.hex().upper()


def gen_sm4_iv(length=16):
    """Return ``length`` bytes from ``gen_random_bytes`` as an upper-case hex IV string."""
    random_bytes = gen_random_bytes(length)
    return random_bytes.hex().upper()


def gen_sm4_key_iv(length=16):
    """
    Generate a key and an IV in one call.

    :param length: number of bytes for each value
    :return: (key, iv) tuple, each an upper-case hex string; the key is drawn first
    """
    return tuple(gen_random_bytes(length).hex().upper() for _ in range(2))


if __name__ == '__main__':
    # Demo: round-trip local files through SM4-CBC file encryption and compare
    # the SM3 digest of each original with that of its decrypted copy.

    # key = gen_sm4_key()
    # print(f'gen_sm4_key={key}')

    key = '2934412A66B7A186DC35DC40E926F9EE'
    iv = '86CD720D75F4622DBE96078A3CD1076E'
    source = '1'

    # String round trip with hex and base64 output. The ECB/GCM string helpers
    # are commented out in this module, so only the CBC demo is kept.
    # sm4_cbc = sm4_cbc_encrypt(key, iv, source)
    # print(f'sm4_cbc={sm4_cbc}')
    # print(f'sm4_cbc_dec={sm4_cbc_decrypt(key, iv, sm4_cbc)}')
    # sm4_cbc_base64 = sm4_cbc_encrypt(key, iv, source, code='base64')
    # print(f'sm4_cbc_base64={sm4_cbc_base64}')
    # print(f'sm4_cbc_dec_base64={sm4_cbc_decrypt(key, iv, sm4_cbc_base64, code="base64")}')

    source_file1 = r'D:\Python\backup\get_com_cert.py\Anantara.JPG'
    source_file2 = r'D:\Python\backup\get_com_cert.py\ceti.txt'
    source_file3 = r'D:\Python\backup\get_com_cert.py\readme.txt'
    source_file4 = r'D:\Python\backup\get_com_cert.py\人间烟火.mp3'
    source_file5 = r'D:\Python\backup\get_com_cert.py\笑傲江湖.mp4'
    source_file6 = r'D:\Python\backup\get_com_cert.py\商用密码知识与政策干部读本.pdf'
    source_file7 = r'D:\Python\backup\get_com_cert.py\Expirence_log.ldf'  # not in the demo set
    source_file_couple = (
        source_file1, source_file2, source_file3,
        source_file4, source_file5, source_file6,
    )

    # Output directories may be chosen via the helpers' ``filepath`` argument, e.g.
    # r'D:/Python/backup/get_com_cert.py/encrypt' and r'D:/Python/backup/get_com_cert.py/decrypt'.

    mode = 'CBC'
    for idx, path in enumerate(source_file_couple, start=1):
        started = datetime.now()
        source_sm3_hash, encrypt_file = file_encrypt_sm4(key, iv, path, mode)
        decrypt_sm3_hash, decrypt_file = file_decrypt_sm4(key, iv, encrypt_file, mode)
        print(f'source_file[{idx}]={path}')
        print(f'decrypt_sm3_hash={decrypt_sm3_hash}')
        print(f'decrypt_file={decrypt_file}')
        print(f'source_sm3_hash==decrypt_sm3_hash={source_sm3_hash == decrypt_sm3_hash}')
        finished = datetime.now()
        elapsed = finished.timestamp() - started.timestamp()
        print(f'start_time={started},end_time={finished},btw_second={elapsed}')
        print('-' * 66)