SM2 WWMM
Sun Yimin edited this page Feb 23, 2024
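MFMM = Montgomery-Friendly Modulus Montgomery Multiplication
The primes p of NIST P-256 and SM2 (both 256 bits) are Montgomery-friendly moduli: their lowest 64-bit word is p0 = 2^64 - 1, i.e. p ≡ -1 (mod 2^64), so -p^(-1) mod 2^64 = 1.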
Input:
X and Y are both in Montgomery representation.
X and Y are written as 64-bit words:
X = X3 * 2^192 + X2 * 2^128 + X1 * 2^64 + X0
Y = Y3 * 2^192 + Y2 * 2^128 + Y1 * 2^64 + Y0
0<=X, Y < p
Output:
X * Y * 2^(-256) mod p
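acc0, acc1, acc2, acc3, acc4 and acc5 are 64-bit registers.
The first step computes tmp = X * Y0, with the result tmp = acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2^64 + acc0.
The products of X with the higher words of Y (Y1, Y2, Y3) are all multiples of 2^64, so X * Y mod 2^64 = tmp mod 2^64 = acc0.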
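The first step, tmp = X * Y0, is an ordinary 4-by-1 word multiplication. The sketch below assumes little-endian word arrays (x[0] = X0) and uses math/bits.Mul64 and Add64; mulWord is an illustrative name, not a gmsm function.

```go
package main

import (
	"fmt"
	"math/bits"
)

// mulWord returns the 5-word product X * y, least significant word first,
// i.e. (acc0, acc1, acc2, acc3, acc4) for tmp = X * Y0.
func mulWord(x [4]uint64, y uint64) [5]uint64 {
	var acc [5]uint64
	var carry uint64
	for i := 0; i < 4; i++ {
		hi, lo := bits.Mul64(x[i], y)
		var c uint64
		acc[i], c = bits.Add64(lo, carry, 0)
		carry = hi + c // no overflow: hi <= 2^64 - 2
	}
	acc[4] = carry
	return acc
}

func main() {
	x := [4]uint64{0xffffffffffffffff, 2, 3, 4}
	y0 := uint64(0x1234)
	acc := mulWord(x, y0)
	// Only X0 * Y0 contributes to the lowest word, so acc0 = X0 * Y0 mod 2^64.
	fmt.Println(acc[0] == x[0]*y0) // true
}
```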
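Here p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0, and for both NIST P-256 and SM2, p0 = 2^64 - 1. Since -p^(-1) mod 2^64 = 1, the Montgomery multiplier for this word is acc0 itself: acc0 + acc0 * p0 = acc0 * 2^64, so tmp + acc0 * p is divisible by 2^64.
Expanding (tmp + acc0 * p) / 2^64: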
= (acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2 ^ 64 + acc0 + acc0 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1)) / 2^64
= acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1 + acc0 * p3 * 2^128 + acc0*p2*2^64+acc0*p1+acc0
= acc4 * 2^192 + (acc3 + acc0 * p3) * 2^128 + (acc2+acc0*p2) * 2^64 + acc0*p1 + acc1 + acc0
(carry1, acc1) = acc0 + acc1 + acc0 * p1
(carry2, acc2) = carry1 + acc2 + acc0 * p2
(carry3, acc3) = carry2 + acc3 + acc0 * p3
(carry4, acc4) = carry3 + acc4
acc5 = carry4
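After carry propagation, the result is tmp = acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1. In this notation each carry_i is the full high 64-bit word of its sum, not just a single bit; it always fits in one word because a + b + c * d < 2^128 for 64-bit words a, b, c, d.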
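The carry chain above can be sketched in Go with math/bits. This is a generic version for any p with p0 = 2^64 - 1, so it covers both P-256 and SM2; the 6-word input also covers the later rounds, where tmp has an extra top word. mac, reduceStep, toBig and refStep are hypothetical helpers, and the result is checked against math/big.

```go
package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// mac returns (hi, lo) of a + b*c + carry; the sum always fits in 128 bits.
func mac(a, b, c, carry uint64) (hi, lo uint64) {
	hi, lo = bits.Mul64(b, c)
	var cc uint64
	lo, cc = bits.Add64(lo, a, 0)
	hi += cc
	lo, cc = bits.Add64(lo, carry, 0)
	hi += cc
	return
}

// reduceStep computes (t + t0*p) / 2^64 for a 6-word t (least significant first),
// assuming p0 = 2^64 - 1 so that the multiplier is t0 itself.
func reduceStep(t [6]uint64, p [4]uint64) [5]uint64 {
	m := t[0]
	var r [5]uint64
	var c uint64
	c, r[0] = mac(t[1], m, p[1], m)  // (carry1, acc1) = acc0 + acc1 + acc0*p1
	c, r[1] = mac(t[2], m, p[2], c)  // (carry2, acc2) = carry1 + acc2 + acc0*p2
	c, r[2] = mac(t[3], m, p[3], c)  // (carry3, acc3) = carry2 + acc3 + acc0*p3
	r[3], c = bits.Add64(t[4], c, 0) // (carry4, acc4) = carry3 + acc4
	r[4] = t[5] + c                  // acc5 = carry4 (plus any existing top word)
	return r
}

// toBig converts little-endian 64-bit words to a big.Int.
func toBig(w []uint64) *big.Int {
	z := new(big.Int)
	for i := len(w) - 1; i >= 0; i-- {
		z.Lsh(z, 64)
		z.Or(z, new(big.Int).SetUint64(w[i]))
	}
	return z
}

// refStep is the math/big reference: (t + t0*p) / 2^64.
func refStep(t [6]uint64, p [4]uint64) *big.Int {
	s := new(big.Int).Mul(new(big.Int).SetUint64(t[0]), toBig(p[:]))
	s.Add(s, toBig(t[:]))
	return s.Rsh(s, 64)
}

func main() {
	sm2 := [4]uint64{0xffffffffffffffff, 0xffffffff00000000, 0xffffffffffffffff, 0xfffffffeffffffff}
	t := [6]uint64{0x0123456789abcdef, 0xfedcba9876543210, 0xffffffffffffffff, 42, 7, 0}
	r := reduceStep(t, sm2)
	fmt.Println(toBig(r[:]).Cmp(refStep(t, sm2)) == 0) // true
}
```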
Let H = high64(acc0 * 2^32), the part that overflows 64 bits, and L = low64(acc0 * 2^32); that is, H = acc0 >> 32 and L = (acc0 << 32) mod 2^64.
NIST P-256
p = 0xffffffff00000001 0000000000000000 00000000ffffffff ffffffffffffffff
= p3 * 2^192 + p1 * 2^64 + 2^64 - 1, where p2 = 0 and p1 = 2^32 - 1
p * acc0 = acc0 * p3 * 2^192 + acc0 * p1 * 2^64 + acc0 * 2^64 - acc0
(tmp + acc0 * p) / 2^64 = acc0 * p3 * 2^128 + acc0 * p1 + acc0 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1
=acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + acc2 * 2^64 + acc0 + acc1 + acc0* (2^32 - 1)
=acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + acc2 * 2^64 + acc1 + acc0* 2^32
=acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + (acc2 + H(acc0* 2^32))* 2^64 + acc1 + L(acc0* 2^32)
amd64 assembly:
MOVQ acc0, AX
MOVQ acc0, t1
SHLQ $32, acc0 // L(acc0 * 2^32)
MULQ p256const1<>(SB) // (DX, AX) = AX * p3 = acc0 * p3, DX holds the high 64 bits
SHRQ $32, t1 // t1 = H(acc0 * 2^32)
ADDQ acc0, acc1 // (carry1, acc1) = acc1 + L(acc0 * 2^32)
ADCQ t1, acc2 // (carry2, acc2) = carry1 + acc2 + H(acc0 * 2^32)
ADCQ AX, acc3 // (carry3, acc3) = carry2 + acc3 + L(acc0 * p3)
ADCQ DX, acc4 // (carry4, acc4) = carry3 + acc4 + H(acc0 * p3)
ADCQ $0, acc5 // acc5 = carry4
XORQ acc0, acc0 // acc0 = 0
The result is held in five 64-bit registers: (acc5, acc4, acc3, acc2, acc1).
arm64 assembly:
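For P-256 the reduction step needs only one real multiplication. This Go sketch mirrors the amd64 sequence above (a shift for acc0 * 2^32, one Mul64 for acc0 * p3) and checks it against math/big; p256ReduceStep, toBig and refStep are illustrative names, not gmsm functions.

```go
package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// P-256 prime, little-endian words: p0 = 2^64-1, p1 = 2^32-1, p2 = 0, p3 = 0xffffffff00000001.
var p256 = [4]uint64{0xffffffffffffffff, 0x00000000ffffffff, 0, 0xffffffff00000001}

// p256ReduceStep computes (t + t0*p) / 2^64 like the amd64 sequence:
// acc0*p1 + acc0 = acc0*2^32 = H*2^64 + L, so only acc0*p3 needs a multiply.
func p256ReduceStep(t [6]uint64) [5]uint64 {
	a0 := t[0]
	l, h := a0<<32, a0>>32            // L(acc0*2^32), H(acc0*2^32)
	hi, lo := bits.Mul64(a0, p256[3]) // (DX, AX) = acc0 * p3
	var r [5]uint64
	var c uint64
	r[0], c = bits.Add64(t[1], l, 0)  // ADDQ: acc1 + L
	r[1], c = bits.Add64(t[2], h, c)  // ADCQ: carry1 + acc2 + H
	r[2], c = bits.Add64(t[3], lo, c) // ADCQ: carry2 + acc3 + L(acc0*p3)
	r[3], c = bits.Add64(t[4], hi, c) // ADCQ: carry3 + acc4 + H(acc0*p3)
	r[4] = t[5] + c                   // ADCQ $0, acc5
	return r
}

// toBig converts little-endian 64-bit words to a big.Int.
func toBig(w []uint64) *big.Int {
	z := new(big.Int)
	for i := len(w) - 1; i >= 0; i-- {
		z.Lsh(z, 64)
		z.Or(z, new(big.Int).SetUint64(w[i]))
	}
	return z
}

// refStep is the math/big reference: (t + t0*p) / 2^64.
func refStep(t [6]uint64) *big.Int {
	s := new(big.Int).Mul(new(big.Int).SetUint64(t[0]), toBig(p256[:]))
	s.Add(s, toBig(t[:]))
	return s.Rsh(s, 64)
}

func main() {
	t := [6]uint64{0xfedcba9876543210, 1, 2, 3, 4, 0}
	r := p256ReduceStep(t)
	fmt.Println(toBig(r[:]).Cmp(refStep(t)) == 0) // true
}
```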
ADDS acc0<<32, acc1, acc1 // (carry1, acc1) = acc1 + L(acc0 * 2^32)
LSR $32, acc0, t0 // t0 = H(acc0 * 2^32)
MUL acc0, const1, t1 // t1 = L(acc0 * p3)
UMULH acc0, const1, acc0 // acc0 = H(acc0 * p3)
ADCS t0, acc2 // (carry2, acc2) = carry1 + acc2 + H(acc0 * 2^32)
ADCS t1, acc3 // (carry3, acc3) = carry2 + acc3 + L(acc0 * p3)
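ADC $0, acc0 // acc0 = carry3 + H(acc0 * p3). Unlike amd64, which uses (acc5, acc4, acc3, acc2, acc1), the arm64 implementation represents the result of the first reduction as ((acc0, acc4), acc3, acc2, acc1).
The result is again held in five 64-bit registers: (acc4, acc3, acc2, acc1) plus (acc0, 0, 0, 0), i.e. acc0 is a pending word at the acc4 position that still has to be added in. MUL, UMULH and LSR do not modify the condition flags, so the carry from ADDS reaches the ADCS chain intact.
In addition, ZR on arm64 is the zero register: ADC $0, ZR, acc5 means acc5 = carry + 0 + 0.
SM2 curve
p = 0xfffffffeffffffff ffffffffffffffff ffffffff00000000 ffffffffffffffff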
= (2^64 - 2^32 - 1) * 2^192 + (2^64 - 1) * 2^128 + (2^64 - 2^32) * 2^64 + (2^64 - 1)
= (2^64 - 2^32 - 1) * 2^192 + (2^64 - 1) * 2^128 + (2^64 - 2^32 + 1) * 2^64 - 1
= (2^64 - 2^32 - 1) * 2^192 + (2^64 ) * 2^128 + ( - 2^32 + 1) * 2^64 - 1
= (2^64 - 2^32 ) * 2^192 + ( - 2^32 + 1) * 2^64 - 1
= 2^256 + (-2^32) * 2^192 + (1-2^32)*2^64 - 1
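A quick math/big check of the derivation above: the SM2 prime built from its hex words, from the word expressions in the first line, and from the powers of two in the last line (which simplifies to 2^256 - 2^224 - 2^96 + 2^64 - 1) should all agree. The function names are illustrative.

```go
package main

import (
	"fmt"
	"math/big"
)

// pow2 returns 2^n.
func pow2(n uint) *big.Int { return new(big.Int).Lsh(big.NewInt(1), n) }

// sm2P parses the SM2 prime from its hex words (most significant first).
func sm2P() *big.Int {
	p, _ := new(big.Int).SetString(
		"fffffffeffffffff"+"ffffffffffffffff"+"ffffffff00000000"+"ffffffffffffffff", 16)
	return p
}

// sm2PFromWords evaluates p3*2^192 + p2*2^128 + p1*2^64 + p0 with
// p3 = 2^64 - 2^32 - 1, p2 = 2^64 - 1, p1 = 2^64 - 2^32, p0 = 2^64 - 1.
func sm2PFromWords() *big.Int {
	two64 := pow2(64)
	p3 := new(big.Int).Sub(two64, pow2(32))
	p3.Sub(p3, big.NewInt(1))
	p2 := new(big.Int).Sub(two64, big.NewInt(1))
	p1 := new(big.Int).Sub(two64, pow2(32))
	p0 := new(big.Int).Sub(two64, big.NewInt(1))
	z := new(big.Int)
	for _, w := range []*big.Int{p3, p2, p1, p0} {
		z.Lsh(z, 64)
		z.Add(z, w)
	}
	return z
}

// sm2PFromPowers evaluates 2^256 + (-2^32)*2^192 + (1 - 2^32)*2^64 - 1.
func sm2PFromPowers() *big.Int {
	q := pow2(256)
	q.Sub(q, pow2(224))
	q.Sub(q, pow2(96))
	q.Add(q, pow2(64))
	return q.Sub(q, big.NewInt(1))
}

func main() {
	fmt.Println(sm2P().Cmp(sm2PFromWords()) == 0, sm2P().Cmp(sm2PFromPowers()) == 0) // true true
}
```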
p = p3 * 2^192 + p2*2^128 + p1 * 2^64 + 2^64 - 1
(tmp + acc0 * p) / 2^64 = acc4 * 2^192 + (acc3 + acc0*p3) * 2^128 + (acc2 + acc0*p2) * 2^64 + acc1 + acc0*p1 + acc0
amd64 assembly:
MOVQ p256p<>+0x08(SB), AX
MULQ acc0
ADDQ acc0, acc1 // (carry1, acc1) = acc0 + acc1
ADCQ $0, DX // DX = carry1 + H(acc0 * p1)
ADDQ AX, acc1 // (carry2, acc1) = acc0 + acc1 + L(acc0*p1)
ADCQ $0, DX // DX = DX + carry2
MOVQ DX, t1 // t1 = H(acc0 * p1) + carry1 + carry2
MOVQ p256p<>+0x010(SB), AX
MULQ acc0
ADDQ t1, acc2 // (carry3, acc2) = t1 + acc2
ADCQ $0, DX // DX = carry3 + H(acc0 * p2)
ADDQ AX, acc2 // (carry4, acc2) = L(acc0 * p2) + L(t1 + acc2)
ADCQ $0, DX // DX = DX + carry4
MOVQ DX, t1 // t1 = H(acc0 * p2) + carry3 + carry4
MOVQ p256p<>+0x018(SB), AX
MULQ acc0
ADDQ t1, acc3 // (carry5, acc3) = t1 + acc3
ADCQ $0, DX // DX = carry5 + H(acc0 * p3)
ADDQ AX, acc3 // (carry6, acc3) = L(acc0 * p3) + L(t1 + acc3)
ADCQ DX, acc4 // (carry7, acc4) = acc4 + DX + carry6
ADCQ $0, acc5 // acc5 = carry7
XORQ acc0, acc0
arm64 assembly:
MUL const1, acc0, t0
ADDS t0, acc1, acc1 // (carry1, acc1) = acc1 + L(acc0*p1)
UMULH const1, acc0, y0 // y0 = H(acc0*p1)
MUL const2, acc0, t0
ADCS t0, acc2, acc2 // (carry2, acc2) = carry1 + acc2 + L(acc0*p2)
UMULH const2, acc0, hlp0 // hlp0 = H(acc0*p2)
MUL const3, acc0, t0 // t0 = L(acc0*p3)
ADCS t0, acc3, acc3 // (carry3, acc3) = carry2 + acc3 + L(acc0*p3)
UMULH const3, acc0, hlp1 // hlp1 = H(acc0*p3). In practice hlp1 cannot be used here, because p256PointAddAsm uses this register globally.
ADC $0, acc4 // acc4 = carry3 + acc4
ADDS acc0, acc1, acc1 // (carry4, acc1) = acc0 + acc1 + L(acc0*p1)
ADCS y0, acc2, acc2 // (carry5, acc2) = carry4 + acc2 + L(acc0*p2) + H(acc0*p1)
ADCS hlp0, acc3, acc3 // (carry6, acc3) = carry5 + acc3 + L(acc0*p3) + H(acc0*p2)
ADC $0, hlp1, acc0 // acc0 = carry6 + H(acc0*p3)
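I don't have an arm64 machine at hand, so I can only rely on Travis CI to verify the code, which is slow and inefficient; switching to arm64-graviton made it much better.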
======
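Replacing the multiplications with additions and subtractions carries potential risks: the carry/borrow handling is too complex, so this implementation has been rolled back.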
p*acc0 = acc0*2^256 -(acc0*2^32)*2^192 + (acc0 - acc0*2^32)*2^64 - acc0
(tmp + acc0 * p) / 2^64 = (acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2^64 + acc0 + acc0*2^256 - (acc0*2^32)*2^192 + (acc0 - acc0*2^32)*2^64 - acc0) / 2^64
= (acc4+acc0)*2^192 + (acc3 - acc0*2^32) * 2^128 + acc2 * 2^64 + (acc1 + acc0 - acc0*2^32)
= (acc4+acc0)*2^192 + (acc3 - acc0*2^32) * 2^128 + (acc2 - H(acc0*2^32)) * 2^64 + (acc1 + acc0 - L(acc0*2^32))
(carry1, acc1) = acc0+acc1
acc2 = carry1 + acc2 // Can this carry? Yes, when acc2 = 0xffffffffffffffff
(carry2, acc1) = acc1 - L
acc2 = acc2 - H - carry2 // Can this borrow? Yes, when acc0 is large enough and acc2 is small enough
(carry3, acc3) = acc0 + acc3
t1 = acc0 + carry3 // Can this carry?
(carry4, acc3) = acc3 - L
t1 = t1 - H - carry4 // Can this become negative? No
(carry5, acc3) = acc3 - acc0
t1 = t1 - carry5 // Can this become negative? No
(carry6, acc4) = acc4 + t1
acc5 = carry6
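The first risk noted in the comments can be reproduced with math/bits: when acc2 = 0xffffffffffffffff and carry1 = 1, the sum wraps to 0, and the carry-out that the line acc2 = carry1 + acc2 drops would have to propagate into acc3. stepRolledBack is a hypothetical helper written only for this illustration.

```go
package main

import (
	"fmt"
	"math/bits"
)

// stepRolledBack evaluates the first two lines of the rolled-back sketch,
// (carry1, acc1) = acc0 + acc1 and acc2 = carry1 + acc2, and also returns
// the carry-out of the second line, which the sketch does not keep.
func stepRolledBack(acc0, acc1, acc2 uint64) (newAcc1, newAcc2, lost uint64) {
	var carry1 uint64
	newAcc1, carry1 = bits.Add64(acc0, acc1, 0)
	newAcc2, lost = bits.Add64(acc2, carry1, 0)
	return
}

func main() {
	// acc0 + acc1 overflows, and acc2 = 0xffffffffffffffff then wraps to 0.
	a1, a2, lost := stepRolledBack(1, 0xffffffffffffffff, 0xffffffffffffffff)
	fmt.Println(a1, a2, lost) // 0 0 1
	// Borrow case: a small acc2 minus a large H = acc0 >> 32 must borrow from acc3.
	_, borrow := bits.Sub64(3, uint64(0xffffffffffffffff)>>32, 0)
	fmt.Println(borrow) // 1
}
```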
======
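The final implementation uses the following algorithm (essentially one round of additions followed by one round of subtractions), which is nicely symmetric:
      acc4,           acc3,           acc2,           acc1
+     acc0,           0,              0,              acc0
-     H(acc0*2^32),   L(acc0*2^32),   H(acc0*2^32),   L(acc0*2^32)
That is, with H and L as defined above, (tmp + acc0 * p) / 2^64 = (acc4 + acc0 - H) * 2^192 + (acc3 - L) * 2^128 + (acc2 - H) * 2^64 + (acc1 + acc0 - L), with the final carry and borrow absorbed by acc5.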
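MOVQ acc0, AX
MOVQ acc0, DX
SHLQ $32, AX      // AX = L(acc0*2^32)
SHRQ $32, DX      // DX = H(acc0*2^32)
ADDQ acc0, acc1   // add (acc0, 0, 0, acc0)
ADCQ $0, acc2
ADCQ $0, acc3
ADCQ acc0, acc4
ADCQ $0, acc5     // final carry goes into acc5
SUBQ AX, acc1     // subtract (H, L, H, L)
SBBQ DX, acc2
SBBQ AX, acc3
SBBQ DX, acc4
SBBQ $0, acc5     // final borrow comes out of acc5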
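A Go sketch of this add-then-subtract reduction for SM2, mirroring the amd64 sequence instruction by instruction. The whole computation is effectively done modulo 2^320; because the true result (tmp + acc0 * p) / 2^64 is non-negative and below 2^320 for the operand sizes that occur here, intermediate wrap-arounds cancel out. sm2ReduceStep, toBig and refStep are illustrative names, and the result is compared against math/big.

```go
package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// SM2 prime, little-endian 64-bit words.
var sm2P = [4]uint64{0xffffffffffffffff, 0xffffffff00000000, 0xffffffffffffffff, 0xfffffffeffffffff}

// sm2ReduceStep computes (t + t0*p) / 2^64 for SM2 without multiplications:
// (acc5..acc1) += (acc0, 0, 0, acc0), then (acc5..acc1) -= (H, L, H, L),
// where H:L = acc0 * 2^32. t holds acc0..acc5, least significant first.
func sm2ReduceStep(t [6]uint64) [5]uint64 {
	a0 := t[0]
	l, h := a0<<32, a0>>32 // AX = L(acc0*2^32), DX = H(acc0*2^32)
	r := [5]uint64{t[1], t[2], t[3], t[4], t[5]}
	var c, b uint64
	r[0], c = bits.Add64(r[0], a0, 0) // ADDQ acc0, acc1
	r[1], c = bits.Add64(r[1], 0, c)  // ADCQ $0, acc2
	r[2], c = bits.Add64(r[2], 0, c)  // ADCQ $0, acc3
	r[3], c = bits.Add64(r[3], a0, c) // ADCQ acc0, acc4
	r[4], _ = bits.Add64(r[4], 0, c)  // ADCQ $0, acc5
	r[0], b = bits.Sub64(r[0], l, 0)  // SUBQ AX, acc1
	r[1], b = bits.Sub64(r[1], h, b)  // SBBQ DX, acc2
	r[2], b = bits.Sub64(r[2], l, b)  // SBBQ AX, acc3
	r[3], b = bits.Sub64(r[3], h, b)  // SBBQ DX, acc4
	r[4], _ = bits.Sub64(r[4], 0, b)  // SBBQ $0, acc5
	return r
}

// toBig converts little-endian 64-bit words to a big.Int.
func toBig(w []uint64) *big.Int {
	z := new(big.Int)
	for i := len(w) - 1; i >= 0; i-- {
		z.Lsh(z, 64)
		z.Or(z, new(big.Int).SetUint64(w[i]))
	}
	return z
}

// refStep is the math/big reference: (t + t0*p) / 2^64.
func refStep(t [6]uint64) *big.Int {
	s := new(big.Int).Mul(new(big.Int).SetUint64(t[0]), toBig(sm2P[:]))
	s.Add(s, toBig(t[:]))
	return s.Rsh(s, 64)
}

func main() {
	t := [6]uint64{0xfedcba9876543210, 0x0123456789abcdef, 0, 0xffffffffffffffff, 5, 0}
	r := sm2ReduceStep(t)
	fmt.Println(toBig(r[:]).Cmp(refStep(t)) == 0) // true
}
```

Compared with the multiplication-based version, every step here is a single-word add or subtract with a one-bit carry or borrow, which is what makes the two chains easy to keep straight.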
Next, add X * Y1 to tmp word by word (tmp = tmp + X * Y1):
tmp = tmp + X0*Y1
tmp = tmp + X1*Y1 * 2^64
tmp = tmp + X2*Y1 * 2^128
tmp = tmp + X3*Y1 * 2^192
(carry1, acc1) = acc1 + X0 * Y1
(carry2, acc2) = acc2 + carry1 + X1 * Y1
(carry3, acc3) = acc3 + carry2 + X2 * Y1
(carry4, acc4) = acc4 + carry3 + X3 * Y1
(carry5, acc5) = acc5 + carry4
acc0 = carry5
tmp is now acc0 * 2^320 + acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1.
Compute (tmp + acc1 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0; for both NIST P-256 and SM2, p0 = 2^64 - 1.
Expanding (tmp + acc1 * p) / 2^64:
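The multiply-accumulate step tmp = tmp + X * Yi can be sketched the same way with math/bits. The input is the 5-word running value (acc1, ..., acc5 in the text, least significant first) and the output has one extra word, the carry5 that the text stores in acc0. mulAddWord is an illustrative name, not a gmsm function.

```go
package main

import (
	"fmt"
	"math/bits"
)

// mulAddWord returns t + X*y as six words, least significant first.
// The extra top word is carry5, which the text keeps in acc0.
func mulAddWord(t [5]uint64, x [4]uint64, y uint64) [6]uint64 {
	var r [6]uint64
	var carry uint64
	for i := 0; i < 4; i++ {
		hi, lo := bits.Mul64(x[i], y)
		var c uint64
		lo, c = bits.Add64(lo, t[i], 0) // + acc_i
		hi += c
		r[i], c = bits.Add64(lo, carry, 0) // + carry from the previous word
		carry = hi + c                     // fits: a + b + c*d < 2^128
	}
	r[4], r[5] = bits.Add64(t[4], carry, 0) // (carry5, acc5) = acc5 + carry4; acc0 = carry5
	return r
}

func main() {
	fmt.Println(mulAddWord([5]uint64{1, 2, 3, 4, 5}, [4]uint64{1, 1, 1, 1}, 10)) // [11 12 13 14 5 0]
}
```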
= (acc0*2^320 + acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2 ^ 64 + acc1 + acc1 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1)) / 2^64
= acc0*2^256 + acc5 * 2^192 + (acc4 + acc1*p3)*2^128 + (acc3 + acc1*p2)*2^64 + acc1*p1+ acc2 + acc1
(carry1, acc2) = acc1 + acc2 + acc1 * p1
(carry2, acc3) = carry1 + acc3 + acc1 * p2
(carry3, acc4) = carry2 + acc4 + acc1 * p3
(carry4, acc5) = carry3 + acc5
acc0 = acc0 + carry4
After carry propagation, the result is tmp = acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2^64 + acc2.
Next, add X * Y2 to tmp word by word (tmp = tmp + X * Y2):
tmp = tmp + X0*Y2
tmp = tmp + X1*Y2 * 2^64
tmp = tmp + X2*Y2 * 2^128
tmp = tmp + X3*Y2 * 2^192
(carry1, acc2) = acc2 + X0 * Y2
(carry2, acc3) = acc3 + carry1 + X1 * Y2
(carry3, acc4) = acc4 + carry2 + X2 * Y2
(carry4, acc5) = acc5 + carry3 + X3 * Y2
(carry5, acc0) = acc0 + carry4
acc1 = carry5
tmp is now acc1 * 2^320 + acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2^64 + acc2.
Compute (tmp + acc2 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0; for both NIST P-256 and SM2, p0 = 2^64 - 1.
Expanding (tmp + acc2 * p) / 2^64:
=(acc1*2^320 + acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2 ^ 64 + acc2 + acc2 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1) ) / 2^64
=acc1*2^256 + acc0*2^192 + (acc5+acc2*p3)*2^128 + (acc4+acc2*p2)*2^64 + acc2 * p1 + acc3 + acc2
(carry1, acc3) = acc2 + acc3 + acc2 * p1
(carry2, acc4) = carry1 + acc4 + acc2 * p2
(carry3, acc5) = carry2 + acc5 + acc2 * p3
(carry4, acc0) = carry3 + acc0
acc1 = acc1 + carry4
After carry propagation, the result is tmp = acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2^64 + acc3.
Then add X * Y3 to tmp word by word (tmp = tmp + X * Y3):
tmp = tmp + X0*Y3
tmp = tmp + X1*Y3 * 2^64
tmp = tmp + X2*Y3 * 2^128
tmp = tmp + X3*Y3 * 2^192
(carry1, acc3) = acc3 + X0 * Y3
(carry2, acc4) = acc4 + carry1 + X1 * Y3
(carry3, acc5) = acc5 + carry2 + X2 * Y3
(carry4, acc0) = acc0 + carry3 + X3 * Y3
(carry5, acc1) = acc1 + carry4
acc2 = carry5
tmp is now acc2 * 2^320 + acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2^64 + acc3.
Compute (tmp + acc3 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0; for both NIST P-256 and SM2, p0 = 2^64 - 1.
Expanding (tmp + acc3 * p) / 2^64:
=(acc2*2^320 + acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2 ^ 64 + acc3 + acc3 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1) ) / 2^64
=acc2*2^256 + acc1*2^192 + (acc0+acc3*p3)*2^128 + (acc5+acc3*p2)*2^64 + acc3 * p1 + acc4 + acc3
(carry1, acc4) = acc3 + acc4 + acc3 * p1
(carry2, acc5) = carry1 + acc5 + acc3 * p2
(carry3, acc0) = carry2 + acc0 + acc3 * p3
(carry4, acc1) = carry3 + acc1
acc2 = acc2 + carry4
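T = (acc2, acc1, acc0, acc5, acc4), i.e. T = acc2 * 2^256 + acc1 * 2^192 + acc0 * 2^128 + acc5 * 2^64 + acc4 ≡ X * Y * 2^(-256) (mod p). Since X, Y < p, T < 2p, so one final conditional subtraction of p brings the result into [0, p).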
go test -v -short -bench . -run=^$ ./...
goos: linux
goarch: arm64
pkg: github.com/emmansun/gmsm/sm2
BenchmarkLessThan32_P256
BenchmarkLessThan32_P256-2 3698 279225 ns/op
BenchmarkLessThan32_P256SM2
BenchmarkLessThan32_P256SM2-2 4602 258525 ns/op
BenchmarkMoreThan32_P256
BenchmarkMoreThan32_P256-2 4365 274304 ns/op
BenchmarkMoreThan32_P256SM2
BenchmarkMoreThan32_P256SM2-2 4550 263296 ns/op
PASS
ok github.com/emmansun/gmsm/sm2 4.753s