4

SM3

 2 years ago
source link: https://taardisaa.github.io/2022/03/26/SM3/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

了解一下SM3。

SM3哈希算法

哈希值(杂凑值):长度256比特,即32字节

字:长度32比特,即4字节

$$
A,B,C,D,E,F,G,H
$$

8个字宽的寄存器
$$
B^{(i)}
$$
第i个消息分组
$$
CF
$$
Compress Function,压缩函数
$$
FF_j
$$
布尔函数,随j而取不同表达式
$$
GG_j
$$
布尔函数,随j而取不同表达式
$$
IV
$$
Initial Value,初始值,用于确定压缩函数寄存器的初态
$$
P_0
$$
压缩函数中的置换函数
$$
P_1
$$
消息拓展中的置换函数
$$
T_j
$$
常量,根据j变化
$$
m
$$
消息
$$
m'
$$
填充后的消息

部分定义直接写Python函数了,懒得写Latex数学公式。

IV =7380166f 4914b2b9 172442d7 da8a0600 a96f30bc 163138aa e38dee4d b0fb0e4e

Python claripy模块实现向量似乎是个好想法。

self.IV = claripy.BVV(0x7380166f_4914b2b9_172442d7_da8a0600_a96f30bc_163138aa_e38dee4d_b0fb0e4e, 32*8)
def T_j(self, j: int):
    """Return the round constant T_j for round ``j``.

    Rounds 0..15 use 0x79CC4519 and rounds 16..63 use 0x7A879D8A.

    Raises:
        ValueError: if ``j`` is outside 0..63.
    """
    # ``self`` is required because CF invokes this as ``self.T_j(j)``.
    if 0 <= j <= 15:
        return 0x79CC4519
    if 16 <= j <= 63:
        return 0x7A879D8A
    raise ValueError(f"round index j must be in 0..63, got {j}")
def FF_j(self, X, Y, Z, j: int):
    """Boolean function FF_j applied bitwise to three words.

    Rounds 0..15 return ``X ^ Y ^ Z``; rounds 16..63 return the bitwise
    majority ``(X & Y) | (X & Z) | (Y & Z)``.

    Raises:
        ValueError: if ``j`` is outside 0..63 (previously this fell through
            and returned None).
    """
    if 0 <= j <= 15:
        return X ^ Y ^ Z
    if 16 <= j <= 63:
        return (X & Y) | (X & Z) | (Y & Z)
    raise ValueError(f"round index j must be in 0..63, got {j}")

def GG_j(self, X, Y, Z, j: int):
    """Boolean function GG_j applied bitwise to three words.

    Rounds 0..15 return ``X ^ Y ^ Z``; rounds 16..63 return the bitwise
    choice ``(X & Y) | (~X & Z)`` (Y where X is set, Z where it is clear).

    Raises:
        ValueError: if ``j`` is outside 0..63 (previously this fell through
            and returned None).
    """
    if 0 <= j <= 15:
        return X ^ Y ^ Z
    if 16 <= j <= 63:
        return (X & Y) | (~X & Z)
    raise ValueError(f"round index j must be in 0..63, got {j}")
def ROL32(self, Num, i):
    """Rotate a 32-bit word left by ``i`` bits (``i`` is taken modulo 32).

    ``Num`` may be a plain ``int`` or a 32-bit bitvector object exposing
    ``zero_extend`` and ``chop`` (as the claripy values used elsewhere in
    this code do). The result has the same kind as the input.
    """
    i %= 32
    if isinstance(Num, int):
        # Python ints are unbounded: mask the input and the result so the
        # bits shifted past position 31 wrap around instead of accumulating.
        Num &= 0xFFFFFFFF
        return ((Num << i) | (Num >> (32 - i))) & 0xFFFFFFFF

    # Widen to 64 bits so the bits shifted out on the left survive, then
    # keep only the low 32-bit half of the combined value.
    wide = Num.zero_extend(32)
    return ((wide << i) | (wide >> (32 - i))).chop(32)[1]

def P_0(self, X):
    """Permutation P0 used by the compression function.

    Returns ``X ^ (X <<< 9) ^ (X <<< 17)`` where ``<<<`` is ``self.ROL32``.
    """
    rot9 = self.ROL32(X, 9)
    rot17 = self.ROL32(X, 17)
    return X ^ rot9 ^ rot17

def P_1(self, X):
    """Permutation P1 used by the message expansion.

    Returns ``X ^ (X <<< 15) ^ (X <<< 23)`` where ``<<<`` is ``self.ROL32``.
    """
    rot15 = self.ROL32(X, 15)
    rot23 = self.ROL32(X, 23)
    return X ^ rot15 ^ rot23

对长度为l(小于2**64比特)的消息m,生成256比特的哈希值

长度l的消息

先添加比特1到末尾

再添加k个0

k要满足
$$
l + 1 + k \equiv 448 \pmod{512}
$$
然后再添加64位,该串是长度l的二进制表示。

def _getLenBit(self, m:bytes):
return len(m) * 8

def _getK(self, m:bytes):
"""
l + 1 + k == 448 + 512*m
k == 447 + 512*m - l
"""
i = 0
while True:
if 447 + 512*i - self._getLenBit(m) >= 0:
return 447 + 512*i - self._getLenBit(m)
i += 1

# def _getUint_64_Bytes8(self, m:bytes):
# l = self._getLenBit(m)
# return binascii.unhexlify(hex(l)[2:].zfill(16))


def padding(self, m: bytes):
    """Pad a byte string and return it as one claripy bitvector.

    The padded message is ``m``, a single 1 bit, ``k`` zero bits (see
    ``_getK``) and the 64-bit big-endian bit length of ``m``; its total
    length is a multiple of 512 bits.

    Only whole-byte messages are supported, although the algorithm itself
    is defined at bit granularity.
    """
    k = self._getK(m)
    l = self._getLenBit(m)
    # l is a multiple of 8, so k is always 7 (mod 8): the 1 bit plus the k
    # zero bits form the byte 0x80 followed by (k - 7) // 8 zero bytes.
    # Assembling bytes first also avoids building a zero-width bitvector
    # when m is empty.
    padded = bytes(m) + b"\x80" + b"\x00" * ((k - 7) // 8) + l.to_bytes(8, "big")
    assert (len(padded) * 8) % 512 == 0
    return claripy.BVV(padded, len(padded) * 8)
def recur(self, m):
    """Run the iterative compression over a padded message.

    ``m`` is the padded message as a bitvector whose length is a multiple
    of 512 bits. It is split into 512-bit blocks (``self.B``, count
    ``self.n``) and ``self.V[i + 1] = CF(self.V[i], self.B[i])`` is
    evaluated starting from ``self.V[0] = self.IV``.

    Returns:
        The last chaining value ``self.V[self.n]``. Earlier versions
        returned None and callers had to read ``self.V`` themselves; that
        attribute is still populated.
    """
    self.B = m.chop(512)
    self.n = len(self.B)

    # V and B are both lists of bitvectors; V has one more entry than B.
    self.V = [self.IV]
    for block in self.B:
        self.V.append(self.CF(self.V[-1], block))
    return self.V[self.n]

将填充后的消息m'分组为B,一组512比特。

然后进行迭代压缩。其中主要使用CF压缩函数。

每个消息分组B拓展生成132个字(68个字 W_0 … W_67 和64个字 W'_0 … W'_63),用于CF函数。

def messageExpand(self, B):
    """Expand one 512-bit block into the word lists ``self.W`` and ``self.W_alt``.

    ``self.W`` receives 68 words: the first 16 are the 32-bit pieces of
    ``B``, the rest follow the recurrence below. ``self.W_alt`` receives 64
    words with ``W_alt[j] = W[j] ^ W[j + 4]``.
    """
    words = B.chop(32)
    # Placeholders for W[16..67]; each one is overwritten in the loop below.
    words += [claripy.BVV(0, 32)] * (68 - 16)
    self.W = words
    self.W_alt = [claripy.BVV(0, 32)] * 64

    for j in range(16, 68):
        mixed = words[j - 16] ^ words[j - 9] ^ self.ROL32(words[j - 3], 15)
        words[j] = self.P_1(mixed) ^ self.ROL32(words[j - 13], 7) ^ words[j - 6]

    for j in range(64):
        self.W_alt[j] = words[j] ^ words[j + 4]
def CF(self, inV, inB):
    """Compression function: fold one 512-bit block into the chaining value.

    Args:
        inV: current 256-bit chaining value (split into registers A..H).
        inB: 512-bit message block; expanded through ``messageExpand``.

    Returns:
        The next 256-bit chaining value, ``(A..H after 64 rounds) ^ inV``.

    The per-round register dump is printed only when the instance has a
    truthy ``debug`` attribute; it is off when the attribute is absent.
    """
    self.messageExpand(inB)
    trace = getattr(self, "debug", False)

    A, B, C, D, E, F, G, H = inV.chop(32)
    for j in range(64):
        if trace:
            print((A, B, C, D, E, F, G, H))
        rotA = self.ROL32(A, 12)
        # T_j is a plain int; keep the rotated constant within 32 bits
        # before it is added to the bitvector registers.
        rotT = self.ROL32(self.T_j(j), j) & 0xFFFFFFFF
        SS1 = self.ROL32(rotA + E + rotT, 7)
        SS2 = SS1 ^ rotA
        TT1 = self.FF_j(A, B, C, j) + D + SS2 + self.W_alt[j]
        TT2 = self.GG_j(E, F, G, j) + H + SS1 + self.W[j]
        D = C
        C = self.ROL32(B, 9)
        B = A
        A = TT1
        H = G
        G = self.ROL32(F, 19)
        F = E
        E = self.P_0(TT2)
        if trace:
            print((A, B, C, D, E, F, G, H))
    return (A.concat(B.concat(C.concat(D.concat(E.concat(F.concat(G.concat(H)))))))) ^ inV

256比特

self.V[n]

Python源码

这是我自己实现的。

# from ctypes import c_uint32
import binascii
import claripy

class SM3:
    """SM3 hash function expressed over claripy bitvectors.

    Typical use::

        sm3 = SM3()
        digest = sm3.recur(sm3.padding(b"abc"))   # 256-bit bitvector

    Attributes:
        IV: 256-bit initial chaining value.
        V: chaining values; ``V[0]`` is ``IV`` and ``V[i + 1]`` is the
            value after block ``i``.
        B: 512-bit message blocks of the last message given to ``recur``.
        debug: when True, ``padding`` and ``CF`` print intermediate values.
    """

    # Class-level default so instances created without __init__ still
    # have tracing switched off.
    debug = False

    def __init__(self, debug: bool = False) -> None:
        self.debug = debug
        self.IV = claripy.BVV(0x7380166f_4914b2b9_172442d7_da8a0600_a96f30bc_163138aa_e38dee4d_b0fb0e4e, 32*8)
        self.V = [
            self.IV,
        ]
        self.B = []

    def T_j(self, j: int):
        """Return the round constant for round ``j`` (0..63).

        Raises:
            ValueError: if ``j`` is outside 0..63.
        """
        if 0 <= j <= 15:
            return 0x79CC4519
        if 16 <= j <= 63:
            return 0x7A879D8A
        raise ValueError(f"round index j must be in 0..63, got {j}")

    def FF_j(self, X, Y, Z, j: int):
        """Boolean function FF_j: XOR for rounds 0..15, majority for 16..63.

        Raises:
            ValueError: if ``j`` is outside 0..63.
        """
        if 0 <= j <= 15:
            return X ^ Y ^ Z
        if 16 <= j <= 63:
            return (X & Y) | (X & Z) | (Y & Z)
        raise ValueError(f"round index j must be in 0..63, got {j}")

    def GG_j(self, X, Y, Z, j: int):
        """Boolean function GG_j: XOR for rounds 0..15, choice for 16..63.

        Raises:
            ValueError: if ``j`` is outside 0..63.
        """
        if 0 <= j <= 15:
            return X ^ Y ^ Z
        if 16 <= j <= 63:
            return (X & Y) | (~X & Z)
        raise ValueError(f"round index j must be in 0..63, got {j}")

    def ROL32(self, Num, i):
        """Rotate a 32-bit word left by ``i`` bits (``i`` taken modulo 32).

        ``Num`` may be a plain ``int`` or a 32-bit bitvector; the result has
        the same kind as the input.
        """
        i %= 32
        if isinstance(Num, int):
            # Python ints are unbounded: mask input and result so the bits
            # shifted past position 31 wrap around instead of accumulating.
            Num &= 0xFFFFFFFF
            return ((Num << i) | (Num >> (32 - i))) & 0xFFFFFFFF

        # Widen to 64 bits so the bits shifted out on the left survive,
        # then keep only the low 32-bit half.
        wide = Num.zero_extend(32)
        return ((wide << i) | (wide >> (32 - i))).chop(32)[1]

    def P_0(self, X):
        """Permutation P0 (compression): ``X ^ (X <<< 9) ^ (X <<< 17)``."""
        return X ^ self.ROL32(X, 9) ^ self.ROL32(X, 17)

    def P_1(self, X):
        """Permutation P1 (expansion): ``X ^ (X <<< 15) ^ (X <<< 23)``."""
        return X ^ self.ROL32(X, 15) ^ self.ROL32(X, 23)

    def _getLenBit(self, m: bytes):
        """Return the length of ``m`` in bits."""
        return len(m) * 8

    def _getK(self, m: bytes):
        """Return the number of zero padding bits ``k`` for message ``m``.

        With ``l`` the bit length of ``m``, ``k`` is the smallest
        non-negative integer with ``l + 1 + k == 448 (mod 512)``.
        """
        # % yields a non-negative result for a positive modulus, which is
        # exactly the smallest valid k; no search loop is needed.
        return (447 - self._getLenBit(m)) % 512

    def _padBytes(self, m: bytes) -> bytes:
        """Return ``m`` followed by the SM3 padding, as plain bytes."""
        k = self._getK(m)
        l = self._getLenBit(m)
        # l is a multiple of 8, so k is always 7 (mod 8): the 1 bit plus the
        # k zero bits form the byte 0x80 followed by (k - 7) // 8 zero bytes.
        return bytes(m) + b"\x80" + b"\x00" * ((k - 7) // 8) + l.to_bytes(8, "big")

    def padding(self, m: bytes):
        """Pad a byte string and return it as one claripy bitvector.

        The result is ``m``, a single 1 bit, ``k`` zero bits and the 64-bit
        big-endian bit length of ``m``; its length is a multiple of 512
        bits. Only whole-byte messages are supported, although the
        algorithm itself is defined at bit granularity.
        """
        padded = self._padBytes(m)
        bit_length = len(padded) * 8
        if self.debug:
            print(self._getK(m))
            print(bit_length)
        assert bit_length % 512 == 0
        # Building from bytes also avoids a zero-width bitvector for b"".
        return claripy.BVV(padded, bit_length)

    def recur(self, m):
        """Run the iterative compression over a padded message.

        ``m`` is split into 512-bit blocks (``self.B``, count ``self.n``)
        and ``self.V[i + 1] = CF(self.V[i], self.B[i])`` is evaluated
        starting from ``self.V[0] = self.IV``.

        Returns:
            The last chaining value ``self.V[self.n]``. Earlier versions
            returned None; ``self.V`` is still populated as before.
        """
        self.B = m.chop(512)
        self.n = len(self.B)

        # V and B are both lists of bitvectors; V has one more entry than B.
        self.V = [self.IV]
        for block in self.B:
            self.V.append(self.CF(self.V[-1], block))
        return self.V[self.n]

    def messageExpand(self, B):
        """Expand one 512-bit block into ``self.W`` (68 words) and ``self.W_alt`` (64 words)."""
        self.W = B.chop(32)
        # Placeholders for W[16..67]; each is overwritten in the loop below.
        self.W += [claripy.BVV(0, 32)] * (68 - 16)
        self.W_alt = [claripy.BVV(0, 32)] * 64

        for j in range(16, 68):
            mixed = self.W[j - 16] ^ self.W[j - 9] ^ self.ROL32(self.W[j - 3], 15)
            self.W[j] = self.P_1(mixed) ^ self.ROL32(self.W[j - 13], 7) ^ self.W[j - 6]
        for j in range(64):
            self.W_alt[j] = self.W[j] ^ self.W[j + 4]

    def CF(self, inV, inB):
        """Compression function: fold one 512-bit block into the chaining value.

        Args:
            inV: current 256-bit chaining value (split into registers A..H).
            inB: 512-bit message block.

        Returns:
            The next 256-bit chaining value, ``(A..H after 64 rounds) ^ inV``.
        """
        self.messageExpand(inB)

        A, B, C, D, E, F, G, H = inV.chop(32)
        for j in range(64):
            if self.debug:
                print((A, B, C, D, E, F, G, H))
            rotA = self.ROL32(A, 12)
            SS1 = self.ROL32(rotA + E + self.ROL32(self.T_j(j), j), 7)
            SS2 = SS1 ^ rotA
            TT1 = self.FF_j(A, B, C, j) + D + SS2 + self.W_alt[j]
            TT2 = self.GG_j(E, F, G, j) + H + SS1 + self.W[j]
            D = C
            C = self.ROL32(B, 9)
            B = A
            A = TT1
            H = G
            G = self.ROL32(F, 19)
            F = E
            E = self.P_0(TT2)
            if self.debug:
                print((A, B, C, D, E, F, G, H))
        return (A.concat(B.concat(C.concat(D.concat(E.concat(F.concat(G.concat(H)))))))) ^ inV

# Demo: pad a 64-byte message (b'abcd' repeated 16 times) and run the
# iterative compression over it.
sm3 = SM3()
a = sm3.padding(b'abcd'*16)
# print(a)

# recur() stores every chaining value in sm3.V; the last entry, sm3.V[-1],
# is the value produced after the final block.
sm3.recur(a)
# sm3.messageExpand(a)

C语言实现国密SM3算法 - 简书 (jianshu.com)

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Word-level helpers for SM3.
 *
 * Words are stored in `unsigned long`, which is 64 bits wide on LP64
 * targets, so every rotation and every modular addition is explicitly
 * reduced to its low 32 bits. Rotation counts are reduced modulo 32, which
 * keeps all shift counts in 0..31 (the round index j runs up to 63).
 */
#define SM3_U32(x) ((unsigned long)(x) & 0xFFFFFFFFUL)

#define SHL(x,n) ((x)<<(n))
#define SHR(x,n) ((x)>>(n))
#define ROTL(x,n) SM3_U32((SM3_U32(x) << ((n) & 31)) | (SM3_U32(x) >> ((32 - ((n) & 31)) & 31)))
#define ROTR(x,n) SM3_U32((SM3_U32(x) >> ((n) & 31)) | (SM3_U32(x) << ((32 - ((n) & 31)) & 31)))

// P1: message-expansion recurrence; P2: permutation used inside it.
#define P1(a,b,c,d,e) (P2((a)^(b)^ROTL((c),15))^ROTL((d),7)^(e))
#define P2(a) ((a)^ROTL((a),15)^ROTL((a),23))
// P3: W'[j] = W[j] ^ W[j+4]; P4: permutation used in the compression rounds.
#define P3(a,b) ((a)^(b))
#define P4(a) ((a)^ROTL((a),9)^ROTL((a),17))
// Round constants for rounds 0..15 (T1) and 16..63 (T2).
#define T1 (0x79cc4519UL)
#define T2 (0x7a879d8aUL)
// Boolean functions: the *1 variants serve rounds 0..15, the *2 variants 16..63.
#define FF1(a,b,c) ((a)^(b)^(c))
#define FF2(a,b,c) (((a)&(b))|((a)&(c))|((b)&(c)))
#define GG1(a,b,c) ((a)^(b)^(c))
#define GG2(a,b,c) (((a)&(b))|((~(a))&(c)))
// SS*(A, E, T, j) and TT*(bool, reg, W, A, E, T, j): per-round intermediates.
#define SS1(a,b,c,d) (ROTL((ROTL((a),12)+(b)+ROTL((c),(d))),7))
#define SS2(a,b,c,d) (SS1((a),(b),(c),(d))^ROTL((a),12))
#define TT1(e,f,g,a,b,c,d) SM3_U32((e)+(f)+SS2(a,b,c,d)+(g))
#define TT2(e,f,g,a,b,c,d) SM3_U32((e)+(f)+SS1(a,b,c,d)+(g))

// Initial chaining value, one 32-bit word per element.
unsigned long H[8] = {0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600, 0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e};

/*
 * Print `len` bytes of `str` as upper-case hex in the form
 * "str=[..], len=[N]" followed by a newline. Always returns 0.
 */
int print_str(unsigned char *str, int len)
{
    int idx = 0;

    printf("str=[");

    while(idx < len)
    {
        printf("%02X", str[idx]);
        idx++;
    }

    printf("], len=[%d]\n", len);

    return 0;
}

/*
 * Store the low 32 bits of `a` into b[0..3] in big-endian order
 * (most significant byte first). Always returns 0.
 *
 * Uses shifts rather than reading the in-memory bytes of `a`, so the result
 * does not depend on the byte order of the host.
 */
int sm3_long_to_str(unsigned long a, unsigned char *b)
{
    b[0] = (unsigned char)((a >> 24) & 0xFF);
    b[1] = (unsigned char)((a >> 16) & 0xFF);
    b[2] = (unsigned char)((a >> 8) & 0xFF);
    b[3] = (unsigned char)(a & 0xFF);

    return 0;
}

/*
 * Read a[0..3] as a big-endian 32-bit word and return it; the upper bits
 * of the returned `unsigned long` (if it is wider than 32 bits) are zero.
 *
 * Assembles the value with shifts rather than writing the in-memory bytes
 * of the result, so it does not depend on the byte order of the host.
 */
unsigned long sm3_str_to_long(unsigned char *a)
{
    return ((unsigned long)a[0] << 24)
         | ((unsigned long)a[1] << 16)
         | ((unsigned long)a[2] << 8)
         |  (unsigned long)a[3];
}

/*
 * Append SM3 padding to the `len`-byte message in `str`, in place, and
 * return the padded length in bytes (a multiple of 64).
 *
 * Padding is the byte 0x80, zero bytes up to 56 (mod 64), then the message
 * length in bits as a 64-bit big-endian integer. The caller must provide
 * room for up to 72 extra bytes after the message.
 *
 * The previous version skipped the 0x80 byte and the zero fill entirely
 * when len % 64 == 56; that case now takes the "spill into the next block"
 * path like every other remainder of 56 or more.
 */
int sm3_pad_message(unsigned char *str, int len)
{
    // Widen before multiplying so len * 8 cannot overflow an int.
    unsigned long long bits = (unsigned long long)len * 8ULL;
    int used = len % 64;
    int target;
    int shift;

    str[len++] = 0x80;
    used++;

    // If the 0x80 byte leaves no room for the 8-byte length in this block,
    // fill through to offset 56 of the following block.
    target = (used <= 56) ? 56 : 56 + 64;

    while(used < target)
    {
        str[len++] = 0x00;
        used++;
    }

    for(shift = 56; shift >= 0; shift -= 8)
    {
        str[len++] = (unsigned char)(bits >> shift);
    }

    return len;
}

/*
 * One step of the message-expansion recurrence on big-endian 4-byte words:
 * reads the words at a, b, c, d, e, combines them with the P1 macro and
 * writes the resulting word to f. Always returns 0.
 */
int sm3_group_a(unsigned char *a, unsigned char *b, unsigned char *c, unsigned char *d, unsigned char *e, unsigned char *f)
{
    unsigned long wa = sm3_str_to_long(a);
    unsigned long wb = sm3_str_to_long(b);
    unsigned long wc = sm3_str_to_long(c);
    unsigned long wd = sm3_str_to_long(d);
    unsigned long we = sm3_str_to_long(e);
    unsigned long result = P1(wa,wb,wc,wd,we);

    sm3_long_to_str(result, f);

    return 0;
}

/*
 * Combine the big-endian 4-byte words at a and b with the P3 macro (XOR)
 * and write the resulting word to c. Always returns 0.
 */
int sm3_group_b(unsigned char *a, unsigned char *b, unsigned char *c)
{
    unsigned long left = sm3_str_to_long(a);
    unsigned long right = sm3_str_to_long(b);
    unsigned long mixed = P3(left,right);

    sm3_long_to_str(mixed, c);

    return 0;
}

/*
 * Expand every 64-byte block of the padded message in `str`, in place.
 *
 * Block i (read from str + i*64) is replaced by 512 bytes written at
 * str + i*512: for j = 0..63, the word W[j] followed by W[j] ^ W[j+4].
 * Blocks are processed from last to first so that an expanded block never
 * overwrites input that has not been read yet. The buffer must be able to
 * hold (len / 64) * 512 bytes. Returns that expanded length.
 */
int sm3_str_group(unsigned char *str, int len)
{
    unsigned char block[64];
    unsigned char words[68][4];
    unsigned char mixed[4];
    unsigned char *out;
    int block_count = len / 64;
    int expanded_size = 64 / 16 * 64 * 2;
    int blk;
    int j;

    for(blk = block_count - 1; blk >= 0; blk--)
    {
        // Copy the input block aside before its slot is overwritten.
        memcpy(block, str + blk * 64, sizeof(block));

        // W[0..15] are the 16 words of the block itself.
        memcpy(words, block, sizeof(block));

        for(j = 16; j < 68; j++)
        {
            sm3_group_a(words[j-16], words[j-9], words[j-3], words[j-13], words[j-6], words[j]);
        }

        out = str + blk * expanded_size;

        for(j = 0; j < 64; j++)
        {
            sm3_group_b(words[j], words[j+4], mixed);
            memcpy(out + 8*j, words[j], 4);
            memcpy(out + 8*j + 4, mixed, 4);
        }
    }

    return block_count * expanded_size;
}

/*
 * Run the 64 compression rounds over each 512-byte expanded block in `str`
 * and write the 32-byte result to `summ`. `len` is the expanded length in
 * bytes; len / 512 blocks are processed. Always returns 0.
 *
 * Each expanded block holds 128 four-byte words laid out as 64 pairs:
 * word 2*j is W[j] and word 2*j+1 is W'[j].
 */
int sm3_str_summ(unsigned char *str, unsigned char *summ, int len)
{
    unsigned char words[128][4];
    unsigned long reg[8] = {0};
    unsigned long state[8] = {0};
    unsigned long w, w_alt, tt1, tt2;
    int block_count = len / 512;
    int blk;
    int round;
    int k;

    // Start from the initial chaining value.
    memcpy(state, H, sizeof(state));

    for(blk = 0; blk < block_count; blk++)
    {
        memcpy(words, str + blk * 512, sizeof(words));

        for(k = 0; k < 8; k++)
        {
            reg[k] = state[k];
        }

        for(round = 0; round < 64; round++)
        {
            w_alt = sm3_str_to_long(words[2*round+1]);
            w = sm3_str_to_long(words[2*round]);

            // Rounds 0..15 use FF1/GG1 with T1; rounds 16..63 use FF2/GG2 with T2.
            if(round < 16)
            {
                tt1 = TT1(FF1(reg[0],reg[1],reg[2]),reg[3],w_alt,reg[0],reg[4],T1,round);
                tt2 = TT2(GG1(reg[4],reg[5],reg[6]),reg[7],w,reg[0],reg[4],T1,round);
            }
            else
            {
                tt1 = TT1(FF2(reg[0],reg[1],reg[2]),reg[3],w_alt,reg[0],reg[4],T2,round);
                tt2 = TT2(GG2(reg[4],reg[5],reg[6]),reg[7],w,reg[0],reg[4],T2,round);
            }

            reg[7] = reg[6];
            reg[6] = ROTL(reg[5],19);
            reg[5] = reg[4];
            reg[4] = P4(tt2);
            reg[3] = reg[2];
            reg[2] = ROTL(reg[1],9);
            reg[1] = reg[0];
            reg[0] = tt1;
        }

        // Feed-forward: XOR the round output into the chaining value.
        for(k = 0; k < 8; k++)
        {
            state[k] ^= reg[k];
        }
    }

    for(k = 0; k < 8; k++)
    {
        sm3_long_to_str(state[k], summ + 4*k);
    }

    return 0;
}

/*
 * Demo driver: pads, expands and hashes a short message held in a
 * 4096-byte work buffer, then prints the 32-byte result as hex.
 */
int main()
{
    static const unsigned char sample[] = {0x33, 0x66, 0x77, 0x99};
    unsigned char str[64*8*8] = {0};
    unsigned char str_sm3[32];
    // NOTE(review): only four bytes are set but the length is 5, so the
    // hashed message is 33 66 77 99 00 — confirm this is intended.
    int len = 5;

    memcpy(str, sample, sizeof(sample));

    len = sm3_pad_message(str, len);
    len = sm3_str_group(str, len);
    sm3_str_summ(str, str_sm3, len);

    print_str(str_sm3, 32);

    return 0;
}

看上去是很一般的哈希流程,跟SHA有异曲同工之妙。

302a3ada057c4a73830536d03e683110.pdf (sca.gov.cn)

python实现sm3算法_mt 2333的博客-CSDN博客_python sm3

C语言实现SM3_嘤3的博客-CSDN博客_c语言实现sm3算法

C语言实现国密SM3算法 - 简书 (jianshu.com)


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK