SM2 WWMM

Sun Yimin edited this page Feb 23, 2024 · 1 revision

MFMM = Montgomery-Friendly Modulus Montgomery Multiplication

First, note that the primes p of both NIST P256 and SM2 256 are Montgomery-friendly moduli.

Input:
X and Y are both in Montgomery representation.
X and Y are each represented as 64-bit words:
X = X3 * 2^192 + X2 * 2^128 + X1 * 2^64 + X0
Y = Y3 * 2^192 + Y2 * 2^128 + Y1 * 2^64 + Y0
0<=X, Y < p

Output:
X * Y * 2^(-256) mod p
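
As a reference point, the target value X * Y * 2^(-256) mod p can be computed directly with Go's math/big. This is only an illustrative sketch, not the implementation discussed on this page: the helper names (toMont, montMulRef, productStaysInMontgomeryForm) are hypothetical, and the prime is the SM2 p quoted further down.

```go
package main

import (
	"fmt"
	"math/big"
)

// sm2P is the SM2 prime p, written as four 64-bit words (most significant first).
var sm2P, _ = new(big.Int).SetString(
	"fffffffeffffffff"+"ffffffffffffffff"+"ffffffff00000000"+"ffffffffffffffff", 16)

// radix is the Montgomery radix 2^256.
var radix = new(big.Int).Lsh(big.NewInt(1), 256)

// toMont returns a * 2^256 mod p, the Montgomery representation of a.
func toMont(a *big.Int) *big.Int {
	t := new(big.Int).Mul(a, radix)
	return t.Mod(t, sm2P)
}

// montMulRef returns x * y * 2^(-256) mod p (hypothetical reference helper).
func montMulRef(x, y *big.Int) *big.Int {
	rInv := new(big.Int).ModInverse(radix, sm2P)
	t := new(big.Int).Mul(x, y)
	t.Mul(t, rInv)
	return t.Mod(t, sm2P)
}

// productStaysInMontgomeryForm reports whether multiplying the Montgomery
// forms of a and b yields the Montgomery form of a*b.
func productStaysInMontgomeryForm(a, b int64) bool {
	x, y := toMont(big.NewInt(a)), toMont(big.NewInt(b))
	want := toMont(new(big.Int).Mul(big.NewInt(a), big.NewInt(b)))
	return montMulRef(x, y).Cmp(want) == 0
}

func main() {
	fmt.Println(productStaysInMontgomeryForm(3, 5))
}
```

The word-by-word steps that follow compute the same value without ever forming the full 512-bit product.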

acc0, acc1, acc2, acc3, acc4, acc5 are 64-bit registers.

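Step 1: compute X * Y0

The result is tmp = acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2^64 + acc0.
The products of X with the higher 64-bit words of Y are always multiples of 2^64, so the full product satisfies (X * Y) mod 2^64 = acc0.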
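
A minimal Go sketch of this row multiplication, using math/bits for the 64x64 -> 128-bit products. The function name mulRow and the little-endian word order of the arrays are assumptions made for illustration only.

```go
package main

import (
	"fmt"
	"math/bits"
)

// mulRow returns x * y0 as five 64-bit words, least significant first,
// so the result corresponds to (acc0, acc1, acc2, acc3, acc4).
func mulRow(x [4]uint64, y0 uint64) [5]uint64 {
	var acc [5]uint64
	var carry uint64
	for i := 0; i < 4; i++ {
		hi, lo := bits.Mul64(x[i], y0) // 128-bit product x[i] * y0
		var c uint64
		acc[i], c = bits.Add64(lo, carry, 0)
		carry = hi + c // cannot overflow: x[i]*y0 + carry < 2^128
	}
	acc[4] = carry
	return acc
}

func main() {
	x := [4]uint64{2, 0, 0, 1} // X = 1 * 2^192 + 2
	fmt.Println(mulRow(x, 3))
}
```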

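Step 2 (first reduction step): compute (tmp + acc0 * p) / 2^64

Here p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0, and for both NIST P256 and SM2 p0 = 2^64 - 1, so -p^(-1) mod 2^64 = 1 and the Montgomery multiplier is simply acc0.
We therefore expand (tmp + acc0 * p) / 2^64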
= (acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2 ^ 64 + acc0 + acc0 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1)) / 2^64
= acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1 + acc0 * p3 * 2^128 + acc0*p2*2^64+acc0*p1+acc0
= acc4 * 2^192 + (acc3 + acc0 * p3) * 2^128 + (acc2+acc0*p2) * 2^64 + acc0*p1 + acc1 + acc0

(carry1, acc1) = acc0 + acc1 + acc0 * p1
(carry2, acc2) = carry1 + acc2 + acc0 * p2
(carry3, acc3) = carry2 + acc3 + acc0 * p3
(carry4, acc4) = carry3 + acc4
acc5 = carry4

After carry propagation, the result is tmp = acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1
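
The generic reduction step can be sketched in Go as follows. This is an illustration under stated assumptions, not the assembly used by the library: reduceStep and sm2P are hypothetical names, words are stored least significant first, and the code relies on p0 = 2^64 - 1 so that the multiplier is acc0 itself.

```go
package main

import (
	"fmt"
	"math/bits"
)

// sm2P holds the SM2 prime as little-endian 64-bit words (p0, p1, p2, p3).
var sm2P = [4]uint64{0xffffffffffffffff, 0xffffffff00000000, 0xffffffffffffffff, 0xfffffffeffffffff}

// reduceStep returns (tmp + acc0*p) / 2^64 for tmp = (acc[4], ..., acc[0]).
// It relies on p[0] = 2^64 - 1, which makes the Montgomery multiplier acc0 itself.
func reduceStep(acc [5]uint64, p [4]uint64) [5]uint64 {
	var out [5]uint64
	m := acc[0]
	var carry uint64
	for i := 0; i < 4; i++ {
		hi, lo := bits.Mul64(m, p[i])
		var c1, c2 uint64
		lo, c1 = bits.Add64(lo, acc[i], 0)
		lo, c2 = bits.Add64(lo, carry, 0)
		carry = hi + c1 + c2 // fits: m*p[i] + acc[i] + carry < 2^128
		if i > 0 {
			out[i-1] = lo // word 0 is zero by construction and is dropped
		}
	}
	out[3], out[4] = bits.Add64(acc[4], carry, 0)
	return out
}

func main() {
	// tmp = 1, so the result is (1 + p) / 2^64.
	fmt.Printf("%x\n", reduceStep([5]uint64{1}, sm2P))
}
```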

H = high64(acc0*2^32), i.e. the part that overflows beyond 64 bits; L = low64(acc0*2^32)
NIST P256 curve
p = 0xffffffff00000001 0000000000000000 00000000ffffffff ffffffffffffffff
   = p3 * 2^192 + p1 * 2^64 + 2^64 - 1, with p2 = 0 and p1 = 2^32 - 1
p * acc0 = acc0 * p3 * 2^192 + acc0 * p1 * 2^64 + acc0 * 2^64 - acc0
(tmp + acc0 * p) / 2^64 = acc0 * p3 * 2^128 + acc0 * p1 + acc0 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1
    =acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + acc2 * 2^64 + acc0 + acc1 + acc0* (2^32 - 1)
    =acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + acc2 * 2^64 + acc1 + acc0* 2^32 
    =acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + (acc2 + H(acc0* 2^32))* 2^64 + acc1 + L(acc0* 2^32) 
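
The derivation above can be sketched in Go: the term acc0 * 2^32 is split into L and H with two shifts, and only acc0 * p3 needs a real multiplication. The function name reduceStepP256 and the little-endian array layout are illustrative assumptions.

```go
package main

import (
	"fmt"
	"math/bits"
)

// p3 is the top 64-bit word of the NIST P256 prime.
const p3 = 0xffffffff00000001

// reduceStepP256 returns (tmp + acc0*p) / 2^64 for tmp = (acc[4], ..., acc[0]),
// using shifts for acc0 * 2^32 and one multiplication for acc0 * p3.
func reduceStepP256(acc [5]uint64) [5]uint64 {
	var out [5]uint64
	m := acc[0]
	l := m << 32 // L(acc0 * 2^32)
	h := m >> 32 // H(acc0 * 2^32)
	hi3, lo3 := bits.Mul64(m, p3)
	var c uint64
	out[0], c = bits.Add64(acc[1], l, 0)
	out[1], c = bits.Add64(acc[2], h, c)
	out[2], c = bits.Add64(acc[3], lo3, c)
	out[3], c = bits.Add64(acc[4], hi3, c)
	out[4] = c
	return out
}

func main() {
	// tmp = 1, so the result is (1 + p) / 2^64.
	fmt.Printf("%x\n", reduceStepP256([5]uint64{1}))
}
```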


amd64 assembly:
MOVQ acc0, AX
MOVQ acc0, t1
SHLQ $32, acc0         // L(acc0 * 2^32)
MULQ p256const1<>(SB)  // acc0 * p3 = (DX, AX), DX holds the high 64 bits
SHRQ $32, t1           // t1 = H(acc0 * 2^32)
ADDQ acc0, acc1        // (carry1, acc1) = acc1 + L(acc0 * 2^32)
ADCQ t1, acc2          // (carry2, acc2) = carry1 + acc2 + H(acc0 * 2^32)
ADCQ AX, acc3          // (carry3, acc3) = carry2 + acc3 + L(acc0 * p3)
ADCQ DX, acc4          // (carry4, acc4) = carry3 + acc4 + H(acc0 * p3)
ADCQ $0, acc5          // acc5 = carry4
XORQ acc0, acc0        // acc0 = 0
The result is held in five 64-bit registers: (acc5, acc4, acc3, acc2, acc1)

arm64 assembly:
ADDS	acc0<<32, acc1, acc1  // (carry1, acc1) = acc1 + L(acc0 * 2^32)
LSR	$32, acc0, t0         // t0 = H(acc0 * 2^32)
MUL	acc0, const1, t1      // t1 = L(acc0 * p3)
UMULH	acc0, const1, acc0    // acc0 = H(acc0 * p3)
ADCS	t0, acc2              // (carry2, acc2) = carry1 + acc2 + H(acc0 * 2^32)
ADCS	t1, acc3              // (carry3, acc3) = carry2 + acc3 + L(acc0 * p3)
ADC	$0, acc0              // acc0 = carry3 + H(acc0 * p3); the arm64 implementation keeps the result of the first reduction as ((acc0, acc4), acc3, acc2, acc1) rather than (acc5, acc4, acc3, acc2, acc1) as on amd64.
The result is likewise held in five 64-bit registers: (acc4, acc3, acc2, acc1) and (acc0, 0, 0, 0)
Also, ZR on arm64 is the zero register, so ADC	$0, ZR, acc5 means acc5 = carry + 0 + 0.

SM2 curve
p = 0x fffffffeffffffff ffffffffffffffff ffffffff00000000 ffffffffffffffff
   =  (2^64 - 2^32  - 1) * 2^192 + (2^64 - 1) * 2^128 + (2^64 - 2^32) * 2^64 + (2^64 - 1)
   =  (2^64 - 2^32  - 1) * 2^192 + (2^64 - 1) * 2^128 + (2^64 - 2^32 + 1) * 2^64  - 1
   =  (2^64 - 2^32  - 1) * 2^192 + (2^64 ) * 2^128 + ( - 2^32 + 1) * 2^64  - 1
   =  (2^64 - 2^32 ) * 2^192 +  ( - 2^32 + 1) * 2^64  - 1
   =  2^256 + (-2^32) * 2^192 + (1-2^32)*2^64 - 1
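
The last line of this decomposition is easy to sanity-check with math/big. The sketch below is purely illustrative; the helper names pow2 and sm2PrimeMatchesClosedForm are hypothetical.

```go
package main

import (
	"fmt"
	"math/big"
)

// pow2 returns 2^n as a big.Int.
func pow2(n uint) *big.Int {
	return new(big.Int).Lsh(big.NewInt(1), n)
}

// sm2PrimeMatchesClosedForm reports whether the hexadecimal SM2 prime equals
// 2^256 + (-2^32) * 2^192 + (1 - 2^32) * 2^64 - 1.
func sm2PrimeMatchesClosedForm() bool {
	p, ok := new(big.Int).SetString(
		"fffffffeffffffff"+"ffffffffffffffff"+"ffffffff00000000"+"ffffffffffffffff", 16)
	if !ok {
		return false
	}
	q := pow2(256)
	q.Sub(q, new(big.Int).Mul(pow2(32), pow2(192))) // + (-2^32) * 2^192
	t := new(big.Int).Sub(big.NewInt(1), pow2(32))  // 1 - 2^32
	q.Add(q, t.Mul(t, pow2(64)))                    // + (1 - 2^32) * 2^64
	q.Sub(q, big.NewInt(1))                         // - 1
	return p.Cmp(q) == 0
}

func main() {
	fmt.Println(sm2PrimeMatchesClosedForm())
}
```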

p = p3 * 2^192 + p2*2^128 + p1 * 2^64 + 2^64 - 1
(tmp + acc0 * p) / 2^64 = acc4 * 2^192 + (acc3 + acc0*p3) * 2^128 + (acc2 + acc0*p2) * 2^64 + acc1 + acc0*p1 + acc0

amd64 assembly:
MOVQ p256p<>+0x08(SB), AX
MULQ acc0
ADDQ acc0, acc1             // (carry1, acc1) = acc0 + acc1
ADCQ $0, DX                 // DX = carry1 + H(acc0 * p1)
ADDQ AX, acc1               // (carry2, acc1) = acc0 + acc1 + L(acc0*p1)
ADCQ $0, DX                 // DX = DX + carry2
MOVQ DX, t1                 // t1 = H(acc0 * p1) + carry1 + carry2
MOVQ p256p<>+0x010(SB), AX
MULQ acc0
ADDQ t1, acc2               // (carry3, acc2) = t1 + acc2
ADCQ $0, DX                 // DX = carry3 + H(acc0 * p2)
ADDQ AX, acc2               // (carry4, acc2) = L(acc0 * p2) + L(t1 + acc2)
ADCQ $0, DX                 // DX = DX + carry4
MOVQ DX, t1                 // t1 = H(acc0 * p2) + carry3 + carry4
MOVQ p256p<>+0x018(SB), AX
MULQ acc0
ADDQ t1, acc3               // (carry5, acc3) = t1 + acc3
ADCQ $0, DX                 // DX = carry5 + H(acc0 * p3)
ADDQ AX, acc3               // (carry6, acc3) = L(acc0 * p3) + L(t1 + acc3)
ADCQ DX, acc4               // (carry7, acc4) = acc4 + DX + carry6
ADCQ $0, acc5               // acc5 = carry7
XORQ acc0, acc0

arm64 assembly:
MUL	const1, acc0, t0
ADDS	t0, acc1, acc1       // (carry1, acc1) = acc1 + L(acc0*p1)
UMULH	const1, acc0, y0     // y0 = H(acc0*p1)

MUL	const2, acc0, t0    
ADCS	t0, acc2, acc2       // (carry2, acc2) = acc2 +  L(acc0*p2)
UMULH	const2, acc0, hlp0   // hlp0 = H(acc0*p2)

MUL	const3, acc0, t0    // t0 = L(acc0*p3)
ADCS	t0, acc3, acc3      // (carry3,acc3) = acc3 + L(acc0*p3)

UMULH	const3, acc0, hlp1 // hlp1 = H(acc0*p3); in practice hlp1 cannot be used here, because that register is used globally by p256PointAddAsm
ADC	$0, acc4            // acc4 = carry3 + acc4

ADDS	acc0, acc1, acc1  // (carry4, acc1) = acc0 + acc1 + L(acc0*p1)
ADCS	y0, acc2, acc2    // (carry5, acc2) = carry4 + acc2 +  L(acc0*p2) + H(acc0*p1)
ADCS	hlp0, acc3, acc3  // (carry6, acc3) = carry5 + acc3 + L(acc0*p3) + H(acc0*p2)
ADC	$0, hlp1, acc0    // acc0 = carry6 + H(acc0*p3)

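I have no arm64 machine at hand and can only rely on Travis CI to check the code, which is slow and inefficient, although it became much better after switching to arm64-graviton.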

======
Replacing the multiplications with additions and subtractions is possible, but it is risky because the carry/borrow handling becomes too complicated, so this implementation has been rolled back.
p*acc0 = acc0*2^256 -(acc0*2^32)*2^192 + (acc0 - acc0*2^32)*2^64 - acc0
(tmp + acc0 * p) / 2^64 = (acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2^64 + acc0 + acc0*2^256 -(acc0*2^32)*2^192 + (acc0 - acc0*2^32)*2^64 - acc0) / 2^64
      = (acc4+acc0)*2^192 + (acc3  - acc0*2^32) * 2^128 + acc2 * 2^64 + (acc1 + acc0 - acc0*2^32)
      = (acc4+acc0)*2^192 + (acc3  - acc0*2^32) * 2^128 + (acc2 - H(acc0*2^32)) * 2^64 + (acc1 + acc0 - L(acc0*2^32))

(carry1, acc1) = acc0+acc1
acc2 = carry1 + acc2       // Can this carry? Yes, when acc2 = 0xffffffffffffffff
(carry2, acc1) = acc1 - L
acc2 = acc2 - H - carry2  // Can this borrow? Yes, when acc0 is large enough and acc2 is small enough

(carry3, acc3) = acc0 + acc3
t1 = acc0 + carry3  // Can this carry?
(carry4, acc3) = acc3 - L
t1 = t1 - H - carry4  // Can this go below 0? No
(carry5, acc3) = acc3 - acc0
t1 = t1 - carry5   // Can this go below 0? No

(carry6, acc4) = acc4 + t1
acc5 = carry6
======

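The final implementation uses the following algorithm (essentially one round of additions followed by one round of subtractions), which is quite symmetric: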

   acc4,         acc3,         acc2,        acc1
 + acc0,         0,            0,           acc0
 - H(acc0*2^32)  L(acc0*2^32)  H(acc0*2^32) L(acc0*2^32)

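MOVQ acc0, AX
MOVQ acc0, DX
SHLQ $32, AX           // AX = L(acc0 * 2^32)
SHRQ $32, DX           // DX = H(acc0 * 2^32)

ADDQ acc0, acc1        // addition round: acc1 += acc0
ADCQ $0, acc2
ADCQ $0, acc3
ADCQ acc0, acc4        // acc4 += acc0 + carry
ADCQ $0, acc5
SUBQ AX, acc1          // subtraction round: acc1 -= L(acc0 * 2^32)
SBBQ DX, acc2          // acc2 -= H(acc0 * 2^32) + borrow
SBBQ AX, acc3          // acc3 -= L(acc0 * 2^32) + borrow
SBBQ DX, acc4          // acc4 -= H(acc0 * 2^32) + borrow
SBBQ $0, acc5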
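
The same add-then-subtract reduction can be sketched in Go with bits.Add64 and bits.Sub64 playing the roles of ADCQ and SBBQ. The function name reduceStepSM2 and the little-endian array layout are assumptions for illustration, and acc5 is assumed to be zero on entry.

```go
package main

import (
	"fmt"
	"math/bits"
)

// reduceStepSM2 returns (tmp + acc0*p) / 2^64 for the SM2 prime, where
// tmp = (acc[4], ..., acc[0]), using one addition round and one subtraction round.
func reduceStepSM2(acc [5]uint64) [5]uint64 {
	var out [5]uint64
	m := acc[0]
	l := m << 32 // L(acc0 * 2^32)
	h := m >> 32 // H(acc0 * 2^32)
	var c, b uint64
	// Addition round: + (acc0, 0, 0, acc0)
	out[0], c = bits.Add64(acc[1], m, 0)
	out[1], c = bits.Add64(acc[2], 0, c)
	out[2], c = bits.Add64(acc[3], 0, c)
	out[3], c = bits.Add64(acc[4], m, c)
	out[4] = c
	// Subtraction round: - (H, L, H, L)
	out[0], b = bits.Sub64(out[0], l, 0)
	out[1], b = bits.Sub64(out[1], h, b)
	out[2], b = bits.Sub64(out[2], l, b)
	out[3], b = bits.Sub64(out[3], h, b)
	out[4], _ = bits.Sub64(out[4], 0, b)
	return out
}

func main() {
	// tmp = 1, so the result is (1 + p) / 2^64.
	fmt.Printf("%x\n", reduceStepSM2([5]uint64{1}))
}
```

Because the true result is non-negative, the final borrow out of the top word is always zero and can be discarded.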

Step 3: compute X * Y1 and add it to tmp

tmp = tmp + X * Y1, accumulated one 64-bit word at a time:

tmp = tmp + X0*Y1
tmp = tmp + X1*Y1 * 2^64
tmp = tmp + X2*Y1 * 2^128
tmp = tmp + X3*Y1 * 2^192

(carry1, acc1) = acc1 + X0 * Y1
(carry2, acc2) = acc2 + carry1 + X1 * Y1
(carry3, acc3) = acc3 + carry2 + X2 * Y1
(carry4, acc4) = acc4 + carry3 + X3 * Y1
(carry5, acc5) = acc5 + carry4
acc0 = carry5

Finally, tmp is expressed as acc0*2^320 + acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1
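
A minimal Go sketch of this multiply-accumulate row follows. The name mulAddRow is hypothetical, and the arrays are indexed by weight (least significant first) rather than by register name, so out[0] plays the role of acc1 and out[5] the role of acc0 in the text above.

```go
package main

import (
	"fmt"
	"math/bits"
)

// mulAddRow returns tmp + x*yi as six 64-bit words, least significant first,
// where tmp is the five-word value left by the previous reduction step.
func mulAddRow(tmp [5]uint64, x [4]uint64, yi uint64) [6]uint64 {
	var out [6]uint64
	var carry uint64
	for i := 0; i < 4; i++ {
		hi, lo := bits.Mul64(x[i], yi)
		var c1, c2 uint64
		lo, c1 = bits.Add64(lo, tmp[i], 0)
		out[i], c2 = bits.Add64(lo, carry, 0)
		carry = hi + c1 + c2 // fits: x[i]*yi + tmp[i] + carry < 2^128
	}
	out[4], out[5] = bits.Add64(tmp[4], carry, 0)
	return out
}

func main() {
	max := ^uint64(0)
	// (2^256 - 1) + 1 * 1 = 2^256: the carry ripples into the fifth word.
	fmt.Println(mulAddRow([5]uint64{max, max, max, max, 0}, [4]uint64{1, 0, 0, 0}, 1))
}
```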

Step 4 (second reduction step)

Compute (tmp + acc1 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0 and, for both NIST P256 and SM2, p0 = 2^64 - 1,
so we expand (tmp + acc1 * p) / 2^64
= (acc0*2^320 + acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2 ^ 64 + acc1 + acc1 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1)) / 2^64
= acc0*2^256 + acc5 * 2^192 + (acc4 + acc1*p3)*2^128 + (acc3 + acc1*p2)*2^64 + acc1*p1+ acc2 + acc1

(carry1, acc2) = acc1 + acc2 + acc1 * p1
(carry2, acc3) = carry1 + acc3 + acc1 * p2
(carry3, acc4) = carry2 + acc4 + acc1 * p3
(carry4, acc5) = carry3 + acc5
acc0 = acc0 + carry4

After carry propagation, the result is tmp = acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2^64 + acc2

Step 5: compute X * Y2 and add it to tmp

tmp = tmp + X * Y2, accumulated one 64-bit word at a time:
tmp = tmp + X0*Y2
tmp = tmp + X1*Y2 * 2^64
tmp = tmp + X2*Y2 * 2^128
tmp = tmp + X3*Y2 * 2^192

(carry1, acc2) = acc2 + X0 * Y2
(carry2, acc3) = acc3 + carry1 + X1 * Y2
(carry3, acc4) = acc4 + carry2 + X2 * Y2
(carry4, acc5) = acc5 + carry3 + X3 * Y2
(carry5, acc0) = acc0 + carry4
acc1 = carry5

Finally, tmp is expressed as acc1*2^320 + acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2^64 + acc2

Step 6 (third reduction step)

Compute (tmp + acc2 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0 and, for both NIST P256 and SM2, p0 = 2^64 - 1,
so we expand (tmp + acc2 * p) / 2^64
=(acc1*2^320 + acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2 ^ 64 + acc2 + acc2 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1) ) / 2^64
=acc1*2^256 + acc0*2^192 + (acc5+acc2*p3)*2^128 + (acc4+acc2*p2)*2^64 + acc2 * p1 + acc3 + acc2

(carry1, acc3) = acc2 + acc3 + acc2 * p1
(carry2, acc4) = carry1 + acc4 + acc2 * p2
(carry3, acc5) = carry2 + acc5 + acc2 * p3
(carry4, acc0) = carry3 + acc0
acc1 = acc1 + carry4

After carry propagation, the result is tmp = acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2^64 + acc3

Step 7: compute X * Y3 and add it to tmp

tmp = tmp + X * Y3, accumulated one 64-bit word at a time:
tmp = tmp + X0*Y3
tmp = tmp + X1*Y3 * 2^64
tmp = tmp + X2*Y3 * 2^128
tmp = tmp + X3*Y3 * 2^192

(carry1, acc3) = acc3 + X0 * Y3
(carry2, acc4) = acc4 + carry1 + X1 * Y3
(carry3, acc5) = acc5 + carry2 + X2 * Y3
(carry4, acc0) = acc0 + carry3 + X3 * Y3
(carry5, acc1) = acc1 + carry4
acc2 = carry5

Finally, tmp is expressed as acc2*2^320 + acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2^64 + acc3

Step 8 (last reduction step)

Compute (tmp + acc3 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0 and, for both NIST P256 and SM2, p0 = 2^64 - 1,
so we expand (tmp + acc3 * p) / 2^64
=(acc2*2^320 + acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2 ^ 64 + acc3 + acc3 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1) ) / 2^64
=acc2*2^256 + acc1*2^192 + (acc0+acc3*p3)*2^128 + (acc5+acc3*p2)*2^64 + acc3 * p1 + acc4 + acc3

(carry1, acc4) = acc3 + acc4 + acc3 * p1
(carry2, acc5) = carry1 + acc5 + acc3 * p2
(carry3, acc0) = carry2 + acc0 + acc3 * p3
(carry4, acc1) = carry3 + acc1
acc2 = acc2 + carry4

T = (acc2, acc1, acc0, acc5, acc4)

Step 9: if T >= p, return T - p; otherwise return T.
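
Putting the nine steps together, the whole procedure can be sketched in portable Go and checked against math/big. This is a simplified illustration, not the library's assembly: the names p, montMul, toBig and agreesWithBigInt are hypothetical, the reduction uses generic multiplications instead of the shift/add/sub tricks, and the final conditional subtraction branches on the data, so it is not constant-time.

```go
package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// p holds the SM2 prime as little-endian 64-bit words (p0, p1, p2, p3).
var p = [4]uint64{0xffffffffffffffff, 0xffffffff00000000, 0xffffffffffffffff, 0xfffffffeffffffff}

// montMul returns x*y*2^(-256) mod p for x, y < p, interleaving one
// multiplication row and one reduction step per 64-bit word of y.
func montMul(x, y [4]uint64) [4]uint64 {
	var t [6]uint64 // little-endian accumulator; t[0] is cleared by each reduction
	for j := 0; j < 4; j++ {
		// t += x * y[j]
		var carry uint64
		for i := 0; i < 4; i++ {
			hi, lo := bits.Mul64(x[i], y[j])
			var c1, c2 uint64
			lo, c1 = bits.Add64(lo, t[i], 0)
			t[i], c2 = bits.Add64(lo, carry, 0)
			carry = hi + c1 + c2
		}
		t[4], t[5] = bits.Add64(t[4], carry, 0)
		// t = (t + m*p) / 2^64 with m = t[0], valid because p0 = 2^64 - 1
		m := t[0]
		carry = 0
		for i := 0; i < 4; i++ {
			hi, lo := bits.Mul64(m, p[i])
			var c1, c2 uint64
			lo, c1 = bits.Add64(lo, t[i], 0)
			lo, c2 = bits.Add64(lo, carry, 0)
			if i > 0 {
				t[i-1] = lo
			}
			carry = hi + c1 + c2
		}
		var c uint64
		t[3], c = bits.Add64(t[4], carry, 0)
		t[4], t[5] = t[5]+c, 0
	}
	// Step 9: return t - p if t >= p, otherwise t (branching, not constant-time).
	var r [4]uint64
	var borrow uint64
	for i := 0; i < 4; i++ {
		r[i], borrow = bits.Sub64(t[i], p[i], borrow)
	}
	_, borrow = bits.Sub64(t[4], 0, borrow)
	if borrow != 0 {
		copy(r[:], t[:4])
	}
	return r
}

// toBig converts little-endian words to a big.Int.
func toBig(w [4]uint64) *big.Int {
	z := new(big.Int)
	for i := 3; i >= 0; i-- {
		z.Lsh(z, 64)
		z.Or(z, new(big.Int).SetUint64(w[i]))
	}
	return z
}

// agreesWithBigInt compares montMul with a direct math/big computation.
func agreesWithBigInt(x, y [4]uint64) bool {
	P := toBig(p)
	rInv := new(big.Int).ModInverse(new(big.Int).Lsh(big.NewInt(1), 256), P)
	want := new(big.Int).Mul(toBig(x), toBig(y))
	want.Mul(want, rInv).Mod(want, P)
	return toBig(montMul(x, y)).Cmp(want) == 0
}

func main() {
	fmt.Println(agreesWithBigInt([4]uint64{1, 2, 3, 4}, [4]uint64{5, 6, 7, 8}))
}
```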

Benchmark on AWS arm64-graviton2:

go test -v -short -bench . -run=^$ ./...
goos: linux
goarch: arm64
pkg: github.com/emmansun/gmsm/sm2
BenchmarkLessThan32_P256
BenchmarkLessThan32_P256-2      	    3698	    279225 ns/op
BenchmarkLessThan32_P256SM2
BenchmarkLessThan32_P256SM2-2   	    4602	    258525 ns/op
BenchmarkMoreThan32_P256
BenchmarkMoreThan32_P256-2      	    4365	    274304 ns/op
BenchmarkMoreThan32_P256SM2
BenchmarkMoreThan32_P256SM2-2   	    4550	    263296 ns/op
PASS
ok  	github.com/emmansun/gmsm/sm2	4.753s