diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e8..0113efadd18e9 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,704 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +@@ -739,20 +739,6 @@ + def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">; + } + +-def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{ +- return N->hasOneUse(); +-}]>; +- +-def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)), +- (bf16 (fpround_oneuse Float32Regs:$b)))), +- (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>, +- Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>; +- +-def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)), +- (f16 (fpround_oneuse Float32Regs:$b)))), +- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>, +- Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>; +- + //----------------------------------- + // Selection instructions (selp) + //----------------------------------- +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll +--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll ++++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll +@@ -204,7 +204,7 @@ + ; + ; SM80-LABEL: test_faddx2( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -216,16 +216,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: add.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: add.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_faddx2( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -237,10 +239,12 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: add.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: add.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; +@@ -307,7 +311,7 @@ + ; + ; SM80-LABEL: test_fsubx2( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -319,16 +323,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_fsubx2( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -340,10 +346,12 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; +@@ -410,7 +418,7 @@ + ; + ; SM80-LABEL: test_fmulx2( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -422,16 +430,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_fmulx2( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -443,10 +453,12 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; +@@ -513,7 +525,7 @@ + ; + ; SM80-LABEL: test_fdiv( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -525,16 +537,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: div.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: div.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_fdiv( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -546,16 +560,18 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: div.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: div.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; + ; SM90-LABEL: test_fdiv( + ; SM90: { +-; SM90-NEXT: .reg .b16 %rs<5>; ++; SM90-NEXT: .reg .b16 %rs<7>; + ; SM90-NEXT: .reg .b32 %r<4>; + ; SM90-NEXT: .reg .f32 %f<7>; + ; SM90-EMPTY: +@@ -567,10 +583,12 @@ + ; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM90-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM90-NEXT: div.rn.f32 %f3, %f2, %f1; ++; SM90-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM90-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM90-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM90-NEXT: div.rn.f32 %f6, %f5, %f4; +-; SM90-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM90-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM90-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM90-NEXT: st.param.b32 [func_retval0], %r3; + ; SM90-NEXT: ret; + %r = fdiv <2 x bfloat> %a, %b +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll +--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll ++++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll +@@ -13,7 +13,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; + ; CHECK-DAG: sin.approx.f32 [[RF0:%f[0-9]+]], [[AF0]]; + ; CHECK-DAG: sin.approx.f32 [[RF1:%f[0-9]+]], [[AF1]]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]] ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 { +@@ -28,7 +30,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; + ; CHECK-DAG: cos.approx.f32 [[RF0:%f[0-9]+]], [[AF0]]; + ; CHECK-DAG: cos.approx.f32 [[RF1:%f[0-9]+]], [[AF1]]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]] ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 { +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll +--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll ++++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll +@@ -26,7 +26,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]] + ; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000; + ; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], 0f40000000; +-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]] ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]] ++; SM80-DAG: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; + ; CHECK-NEXT: st.param.b32 [func_retval0], [[R]]; + ; CHECK-NEXT: ret; +@@ -66,7 +68,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; + ; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; + ; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; ++; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; + + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; +@@ -89,7 +93,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; + ; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; + ; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; ++; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; + + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; +@@ -110,7 +116,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; + ; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; + ; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; ++; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK-NEXT: st.param.b32 [func_retval0], [[R]]; + ; CHECK-NEXT: ret; + +@@ -279,7 +287,9 @@ + + ; CHECK-LABEL: test_fptrunc_2xfloat( + ; CHECK: ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[A0]], [[A1]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[A0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[A1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 { +@@ -349,7 +359,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; + ; CHECK-DAG: sqrt.rn.f32 [[RF0:%f[0-9]+]], [[AF0]]; + ; CHECK-DAG: sqrt.rn.f32 [[RF1:%f[0-9]+]], [[AF1]]; +-; CHECK-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 { +@@ -424,7 +436,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; + ; SM80-DAG: cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]]; + ; SM80-DAG: cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]]; +-; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 { +@@ -441,7 +455,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; + ; SM80-DAG: cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]]; + ; SM80-DAG: cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]]; +-; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 { +@@ -454,7 +470,7 @@ + ; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; + ; SM90: cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]]; + ; SM90: cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]]; +-; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 { +@@ -467,7 +483,7 @@ + ; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; + ; SM90: cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]]; + ; SM90: cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]]; +-; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_rint(<2 x bfloat> %a) #0 { +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll +--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll ++++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll +@@ -1,70 +1,41 @@ +-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 + ; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s + ; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %} + + ++; CHECK-LABEL: cvt_rn_bf16x2_f32 + define <2 x bfloat> @cvt_rn_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rn.bf16x2.f32 ++ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + ++; CHECK-LABEL: cvt_rn_relu_bf16x2_f32 + define <2 x bfloat> @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_relu_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.relu.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rn.relu.bf16x2.f32 ++%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + ++; CHECK-LABEL: cvt_rz_bf16x2_f32 + define <2 x bfloat> @cvt_rz_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rz.bf16x2.f32 ++ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + ++; CHECK-LABEL: cvt_rz_relu_bf16x2_f32 + define <2 x bfloat> @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_relu_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_relu_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.relu.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rz.relu.bf16x2.f32 ++%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + + declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float, float) +@@ -72,68 +43,40 @@ + declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float, float) + declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float, float) + ++; CHECK-LABEL: cvt_rn_f16x2_f32 + define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rn.f16x2.f32 ++ %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2); ++ ++ret <2 x half> %val + } + ++; CHECK-LABEL: cvt_rn_relu_f16x2_f32 + define <2 x half> @cvt_rn_relu_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_relu_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.relu.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rn.relu.f16x2.f32 ++%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2); ++ ++ret <2 x half> %val + } + ++; CHECK-LABEL: cvt_rz_f16x2_f32 + define <2 x half> @cvt_rz_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rz.f16x2.f32 ++ %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2); ++ ++ret <2 x half> %val + } + ++; CHECK-LABEL: cvt_rz_relu_f16x2_f32 + define <2 x half> @cvt_rz_relu_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_relu_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_relu_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.relu.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rz.relu.f16x2.f32 ++%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2); ++ ++ret <2 x half> %val + } + + declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float) +@@ -141,64 +84,40 @@ + declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float) + declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float) + ++; CHECK-LABEL: cvt_rn_bf16_f32 + define bfloat @cvt_rn_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rn_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rn.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rn.bf16.f32 ++ %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1); ++ ++ret bfloat %val + } + ++; CHECK-LABEL: cvt_rn_relu_bf16_f32 + define bfloat @cvt_rn_relu_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rn_relu_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rn.relu.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rn.relu.bf16.f32 ++%val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1); ++ ++ret bfloat %val + } + ++; CHECK-LABEL: cvt_rz_bf16_f32 + define bfloat @cvt_rz_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rz_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rz.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rz.bf16.f32 ++ %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1); ++ ++ret bfloat %val + } + ++; CHECK-LABEL: cvt_rz_relu_bf16_f32 + define bfloat @cvt_rz_relu_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rz_relu_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rz.relu.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rz.relu.bf16.f32 ++%val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1); ++ ++ret bfloat %val + } + + declare bfloat @llvm.nvvm.f2bf16.rn(float) +@@ -206,58 +125,13 @@ + declare bfloat @llvm.nvvm.f2bf16.rz(float) + declare bfloat @llvm.nvvm.f2bf16.rz.relu(float) + ++; CHECK-LABEL: cvt_rna_tf32_f32 + define i32 @cvt_rna_tf32_f32(float %f1) { +-; CHECK-LABEL: cvt_rna_tf32_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rna_tf32_f32_param_0]; +-; CHECK-NEXT: cvt.rna.tf32.f32 %r1, %f1; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1) +- ret i32 %val +-} + +-declare i32 @llvm.nvvm.f2tf32.rna(float) ++; CHECK: cvt.rna.tf32.f32 ++ %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1); + +- +-define <2 x bfloat> @fold_ff2bf16x2(float %a, float %b) { +-; CHECK-LABEL: fold_ff2bf16x2( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [fold_ff2bf16x2_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [fold_ff2bf16x2_param_1]; +-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %ah = fptrunc float %a to bfloat +- %bh = fptrunc float %b to bfloat +- %v0 = insertelement <2 x bfloat> poison, bfloat %ah, i64 0 +- %v1 = insertelement <2 x bfloat> %v0, bfloat %bh, i64 1 +- ret <2 x bfloat> %v1 +-} +- +-define <2 x half> @fold_ff2f16x2(float %a, float %b) { +-; CHECK-LABEL: fold_ff2f16x2( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [fold_ff2f16x2_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [fold_ff2f16x2_param_1]; +-; CHECK-NEXT: cvt.rn.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %ah = fptrunc float %a to half +- %bh = fptrunc float %b to half +- %v0 = insertelement <2 x half> poison, half %ah, i64 0 +- %v1 = insertelement <2 x half> %v0, half %bh, i64 1 +- ret <2 x half> %v1 ++ret i32 %val + } ++ ++declare i32 @llvm.nvvm.f2tf32.rna(float) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index ad42b42a95f79..63355e174793a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "03730cdd3d10c5270fe436777a37d50b0838a3bf" - LLVM_SHA256 = "54d843249c75b200f7bf9b7947079fe16fa0b657c4aee4abdde4ac05a9cd5f84" + LLVM_COMMIT = "b3134fa2338388adf8cfb2d77339d0b042eab9f6" + LLVM_SHA256 = "b6024606092290b0838735c26ad1c5c239b3e931136b420af8680e3a1156e759" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index b69c2c9f92c8f..5d7f27f6cad1e 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,114 +1,982 @@ -diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td -index dcf2c77..63ecbab 100644 ---- a/shardy/dialect/sdy/ir/ops.td -+++ b/shardy/dialect/sdy/ir/ops.td -@@ -197,7 +197,7 @@ def Sdy_ShardingGroupOp : Sdy_Op<"sharding_group", - // Op is non-pure since it modifies the internal representation of the - // sharding group. - []>{ -- let summary = "Constrains tensors in the group to have the same sharding."; -+ let summary = "Sharding group operation"; - let description = [{ - This op provides an interface to assign tensors to sharding groups ( - groups of tensors that will be enforced to have identical shardings). -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -index b07398e..24b5f26 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -@@ -468,11 +468,29 @@ OpShardingRuleAttr createOpShardingRule(Operation* op, - // - // Operands: [operand, iota, init_val (scalar), init_arg (scalar)] - // Results: [values, indices] -+ ArrayRef inputShape = -+ getTensorShape(customCall.getOperand(0)); -+ ArrayRef resultShape = -+ getTensorShape(customCall.getResult(0)); -+ int64_t numInputs = 2, numResults = 2; -+ SmallVector operandDims(customCall->getNumOperands(), -+ kNullDim); -+ SmallVector resultDims(customCall->getNumResults(), -+ kNullDim); - return OpShardingRuleBuilder(customCall) - .addPointwiseIfDimSizesMatch( -- getTensorShape(customCall.getOperand(0)), -- getTensorShape(customCall.getResult(0)), -- /*alwaysAddFactor=*/false) -+ inputShape, resultShape, -+ /*alwaysAddFactor=*/false, -+ /*onMismatchFn=*/ -+ [&](int64_t dim, OpShardingRuleBuilder& builder) { -+ std::fill_n(operandDims.begin(), numInputs, dim); -+ resultDims.assign(numResults, kNullDim); -+ builder.addFactor(operandDims, resultDims, inputShape[dim]); -+ resultDims.assign(numResults, dim); -+ std::fill_n(operandDims.begin(), numInputs, kNullDim); -+ builder.addFactor(operandDims, resultDims, -+ resultShape[dim]); -+ }) - .build(); - } - // TODO(b/327191011): output unregistered op stats instead. -diff --git a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir -index 0ff2642..768d43a 100644 ---- a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir -+++ b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir -@@ -259,7 +259,7 @@ func.func @custom_call_top2_of_2d(%arg0: tensor<16x8xf32>) -> (tensor<16x2xf32>, +diff --git a/shardy/dialect/sdy/ir/attrs.td b/shardy/dialect/sdy/ir/attrs.td +index 3e5cc84..028f5a2 100644 +--- a/shardy/dialect/sdy/ir/attrs.td ++++ b/shardy/dialect/sdy/ir/attrs.td +@@ -187,12 +187,6 @@ def Sdy_AxisRef : AttrDef { + static std::function + getMeshComparator(MeshAttr mesh); - // CHECK-LABEL: func @custom_call_approx_topk - func.func @custom_call_approx_topk(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor, %arg3: tensor) -> (tensor<16x2xf32>, tensor<16x2xf32>) { -- // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}> -+ // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}> - %0:2 = stablehlo.custom_call @ApproxTopK(%arg0, %arg1, %arg2, %arg3) { - mhlo.backend_config = { - aggregate_to_topk = true, -@@ -274,7 +274,7 @@ func.func @custom_call_approx_topk(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf +- // Returns a comparator that order axis names lexicographically. +- // 1. Compare the axis names lexicographically. +- // 2. For two axes with the same name, sub-axis is smaller than the full axis. +- // 3. Sort two sub-axes based on comparator of SubAxisInfoAttr. +- bool operator<(const AxisRefAttr &rhs) const; +- + std::string toString() const; - // CHECK-LABEL: func @custom_call_partial_reduce - func.func @custom_call_partial_reduce(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor, %arg3: tensor) -> (tensor<16x2xf32>, tensor<16x2xf32>) { -- // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}> -+ // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}> - %0:2 = stablehlo.custom_call @PartialReduce(%arg0, %arg1, %arg2, %arg3) { - mhlo.backend_config = { - aggregate_to_topk = true, -@@ -289,7 +289,7 @@ func.func @custom_call_partial_reduce(%arg0: tensor<16x4xf32>, %arg1: tensor<16x + // Returns the size of this axis or sub-axis. +diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc +index 12cd856..08a5453 100644 +--- a/shardy/dialect/sdy/ir/dialect.cc ++++ b/shardy/dialect/sdy/ir/dialect.cc +@@ -254,24 +254,6 @@ AxisRefAttr::getMeshComparator(MeshAttr mesh) { + }; + } - // CHECK-LABEL: func @custom_call_partial_reduce_string_backend_config - func.func @custom_call_partial_reduce_string_backend_config(%arg0: tensor<16x4xf32>, %arg1: tensor<16x4xf32>, %arg2: tensor, %arg3: tensor) -> (tensor<16x2xf32>, tensor<16x2xf32>) { -- // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, k], [], [])->([i, l], [i, m]) {i=16, j=1, k=1, l=1, m=1}> -+ // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j], [], [])->([i, k], [i, k]) {i=16, j=4, k=2}> - %0:2 = stablehlo.custom_call @PartialReduce(%arg0, %arg1, %arg2, %arg3) { - backend_config = "{\22log2_reduction\22: 5, \22reduction_dim\22: 1, \22to_apply_type\22: \22comparator\22, \22top_k\22: 2, \22recall_target\22: 0.950000}", - called_computations = [@top_k_gt_f32_comparator]} : +-bool AxisRefAttr::operator<(const AxisRefAttr& rhs) const { +- StringRef name = getName(); +- StringRef rhsName = rhs.getName(); +- if (name != rhsName) { +- return name < rhsName; +- } +- // Both axis-refs have the same name, if one is a sub-axis and the other +- // is the full axis, then the sub-axis is smaller. +- if (!getSubAxisInfo()) { +- return false; +- } +- if (!rhs.getSubAxisInfo()) { +- return true; +- } +- // Both axis-refs are sub-axes. +- return getSubAxisInfo() < rhs.getSubAxisInfo(); +-} +- + std::string AxisRefAttr::toString() const { + return strippedAttrString(*this, /*stripMnemonic=*/true); + } +diff --git a/shardy/dialect/sdy/ir/dialect_test.cc b/shardy/dialect/sdy/ir/dialect_test.cc +index 25f5eae..f975590 100644 +--- a/shardy/dialect/sdy/ir/dialect_test.cc ++++ b/shardy/dialect/sdy/ir/dialect_test.cc +@@ -154,24 +154,6 @@ TEST_F(DialectTest, AxisRefAttrOverlaps) { + checkOverlaps(createSubAxis("x", 1, 2), createSubAxis("x", 4, 2), false); + } + +-TEST_F(DialectTest, AxisRefAttrCompare) { +- auto compare = [](AxisRefAttr a, AxisRefAttr b) { +- EXPECT_TRUE(a < b); +- EXPECT_FALSE(b < a); +- EXPECT_FALSE(a < a); +- EXPECT_FALSE(b < b); +- }; +- +- compare(createAxis("x"), createAxis("y")); +- compare(createSubAxis("x", 1, 4), createSubAxis("y", 2, 4)); +- compare(createAxis("x"), createSubAxis("y", 2, 4)); +- compare(createSubAxis("x", 1, 4), createAxis("y")); +- compare(createSubAxis("x", 1, 4), createSubAxis("x", 2, 4)); +- compare(createSubAxis("x", 1, 4), createSubAxis("x", 1, 8)); +- compare(createSubAxis("x", 1, 4), createAxis("x")); +- compare(createSubAxis("x", 1, 4), createSubAxis("x", 2, 2)); +-} +- + // The test cases are the same as DialectTest.AxisRefAttrOverlaps. + TEST_F(DialectTest, AxisRefAttrGetPrefixWithoutOverlap) { + auto samePrefix = [](AxisRefAttr a, AxisRefAttr b) { +diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +index 1dd94e6..6baf5bb 100644 +--- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc ++++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +-#include + #include + #include + #include ++#include + +-#include "llvm/ADT/Hashing.h" + #include "llvm/ADT/STLExtras.h" + #include "llvm/ADT/SmallVector.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep +@@ -170,72 +169,7 @@ bool axisRefsOverlap(ArrayRef first, + return false; + } + +-struct FactorAxesPair { +- int64_t factorIndex = -1; +- ArrayRef axisRefs; +- +- FactorAxesPair(int64_t factorIndex, ArrayRef axisRefs) +- : factorIndex(factorIndex), axisRefs(axisRefs) {} +- +- FactorAxesPair() = default; +- +- bool operator<(const FactorAxesPair& rhs) const { +- if (factorIndex != rhs.factorIndex) { +- return factorIndex < rhs.factorIndex; +- } +- if (axisRefs.size() != rhs.axisRefs.size()) { +- return axisRefs.size() < rhs.axisRefs.size(); +- } +- for (auto [axisRef, rhsAxisRef] : llvm::zip_equal(axisRefs, rhs.axisRefs)) { +- if (axisRef != rhsAxisRef) { +- return axisRef < rhsAxisRef; +- } +- } +- return false; +- } +- +- bool operator==(const FactorAxesPair& rhs) const { +- return factorIndex == rhs.factorIndex && axisRefs == rhs.axisRefs; +- } +-}; +- +-struct FactorAxesPairInfo : public llvm::DenseMapInfo { +- static unsigned getHashValue(const FactorAxesPair& m) { +- return llvm::hash_combine(m.factorIndex, m.axisRefs); +- } +- static bool isEqual(const FactorAxesPair& lhs, const FactorAxesPair& rhs) { +- return lhs == rhs; +- } +- +- static inline FactorAxesPair getEmptyKey() { return FactorAxesPair(); } +- +- static inline FactorAxesPair getTombstoneKey() { +- return FactorAxesPair(/*factorIndex=*/-2, /*axisRefs=*/{}); +- } +-}; +- +-struct FactorAxesAssignmentCandidate { +- FactorAxesPair factorAxes; +- int64_t count = 0; +- +- FactorAxesAssignmentCandidate(FactorAxesPair factorAxes, int64_t count) +- : factorAxes(factorAxes), count(count) {} +- +- FactorAxesAssignmentCandidate() = default; +- +- bool operator<(const FactorAxesAssignmentCandidate& rhs) const { +- if (count != rhs.count) { +- return count < rhs.count; +- } +- // TODO(enver): Tie-break based on sharded tensor sizes, instead. +- return rhs.factorAxes < factorAxes; +- } +- +- void assignTo(AxesPerFactor& axesPerFactor) { +- axesPerFactor[factorAxes.factorIndex] = +- llvm::to_vector(factorAxes.axisRefs); +- } +-}; ++using FactorAxesPair = std::pair>; + + // Broadly the algorithm is, at each iteration, to pick a {factor,axis} pair + // with the largest count from a list that is initialized with all the +@@ -246,8 +180,10 @@ struct FactorAxesAssignmentCandidate { + AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic( + const ShardingProjection& projection, int64_t numFactors) { + AxesPerFactor factorAxisRefs(numFactors); +- DenseMap factorAxesCounts; +- FactorAxesAssignmentCandidate bestFactorAxes; ++ SmallVector, int64_t>> factorAxesCounts( ++ numFactors); ++ int64_t maxCount = 0; ++ FactorAxesPair bestFactorAxes; + for (const TensorFactorShardings& tensorFactorSharding : + llvm::concat(projection.getOperands(), + projection.getResults())) { +@@ -256,10 +192,12 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic( + if (factorSharding.axisRefs.empty()) { + continue; + } +- FactorAxesPair factorAxes(factorIndex, factorSharding.axisRefs); +- bestFactorAxes = std::max( +- bestFactorAxes, FactorAxesAssignmentCandidate( +- factorAxes, ++factorAxesCounts[factorAxes])); ++ ArrayRef axisRefs = factorSharding.axisRefs; ++ int64_t axesCount = ++factorAxesCounts[factorIndex][axisRefs]; ++ if (axesCount > maxCount) { ++ maxCount = axesCount; ++ bestFactorAxes = FactorAxesPair(factorIndex, axisRefs); ++ } + } + } + +@@ -269,33 +207,38 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic( + // the count of [x] prefix will be two for this factor. + // TODO(enver): Assign an axis to a factor immediately if the count is more + // than floor(n/2) where n is the number of tensors. +- while (bestFactorAxes.count > 0) { +- bestFactorAxes.assignTo(factorAxisRefs); ++ BitVector unseenFactors(numFactors, true); ++ // TODO(enver): Optimize to mark unseen only the factors with an axis. ++ while (maxCount > 0) { ++ factorAxisRefs[bestFactorAxes.first] = ++ llvm::to_vector(bestFactorAxes.second); ++ unseenFactors.reset(bestFactorAxes.first); + // TODO(enver): Tie-breaking currently depends on the order of iteration. + // Consider some heuristic for breaking ties. + // Invalidate axes that overlaps with the picked one across all unseen + // factors. During the iteration, also find the new best. +- FactorAxesAssignmentCandidate nextBestFactorAxes; +- for (auto factorAxesCountIt = factorAxesCounts.begin(); +- factorAxesCountIt != factorAxesCounts.end(); factorAxesCountIt++) { +- const auto& [factorAxes, count] = *factorAxesCountIt; +- // TODO(enver): Relax the overlap check. We need to erase in case of an +- // overlap only if the factor indices appear together in any of the +- // operands or results. +- // TODO(enver): Use FactorAxesPair overlap api instead. +- if (factorAxes.factorIndex == bestFactorAxes.factorAxes.factorIndex || +- axisRefsOverlap(factorAxes.axisRefs, +- bestFactorAxes.factorAxes.axisRefs)) { +- // TODO(enver): Optimize to flip unseen if all the axes of the factor +- // have zero count. +- // Clear the count of overlapping axis, effectively erasing. +- // TODO(enver): Instead of removing from the list, trim the axisRefs, +- // to use the largest prefix that does not overlap with bestAxisRefs. +- factorAxesCounts.erase(factorAxesCountIt); +- continue; ++ maxCount = 0; ++ FactorAxesPair nextBestFactorAxes; ++ for (int factorIndex : unseenFactors.set_bits()) { ++ auto& axesCounts = factorAxesCounts[factorIndex]; ++ for (const auto& [axisRefs, count] : axesCounts) { ++ // TODO(enver): Relax the overlap check. We need to erase in case of an ++ // overlap only if the factor indices appear together in any of the ++ // operands or results. ++ if (axisRefsOverlap(bestFactorAxes.second, axisRefs)) { ++ // TODO(enver): Optimize to flip unseen if all the axes of the factor ++ // have zero count. ++ // Clear the count of overlapping axis, effectively erasing. ++ // TODO(enver): Instead of removing from the list, trim the axisRefs, ++ // to use the largest prefix that does not overlap with bestAxisRefs. ++ axesCounts[axisRefs] = 0; ++ continue; ++ } ++ if (count > maxCount) { ++ maxCount = count; ++ nextBestFactorAxes = FactorAxesPair(factorIndex, axisRefs); ++ } + } +- nextBestFactorAxes = std::max( +- nextBestFactorAxes, FactorAxesAssignmentCandidate(factorAxes, count)); + } + bestFactorAxes = nextBestFactorAxes; + } diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 0d8fe66..509398d 100644 +index 509398d..0113efa 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,13 +1 @@ +@@ -1 +1,704 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp ----- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp --+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp --@@ -1465,7 +1465,7 @@ -- : Suffix; -- -- ValueSet StructValues; --- StructType *StructTy; --+ StructType *StructTy = nullptr; -- Function *newFunction = constructFunctionDeclaration( -- inputs, outputs, EntryFreq, oldFunction->getName() + "." + SuffixToUse, -- StructValues, StructTy); ++diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++@@ -739,20 +739,6 @@ ++ def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">; ++ } ++ ++-def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{ ++- return N->hasOneUse(); ++-}]>; ++- ++-def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)), ++- (bf16 (fpround_oneuse Float32Regs:$b)))), ++- (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>, ++- Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>; ++- ++-def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)), ++- (f16 (fpround_oneuse Float32Regs:$b)))), ++- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>, ++- Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>; ++- ++ //----------------------------------- ++ // Selection instructions (selp) ++ //----------------------------------- ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll ++--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll +++++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll ++@@ -204,7 +204,7 @@ ++ ; ++ ; SM80-LABEL: test_faddx2( ++ ; SM80: { ++-; SM80-NEXT: .reg .b16 %rs<5>; +++; SM80-NEXT: .reg .b16 %rs<7>; ++ ; SM80-NEXT: .reg .b32 %r<4>; ++ ; SM80-NEXT: .reg .f32 %f<7>; ++ ; SM80-EMPTY: ++@@ -216,16 +216,18 @@ ++ ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; ++ ; SM80-NEXT: add.rn.f32 %f3, %f2, %f1; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; ++ ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; ++ ; SM80-NEXT: add.rn.f32 %f6, %f5, %f4; ++-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-NEXT: ret; ++ ; ++ ; SM80-FTZ-LABEL: test_faddx2( ++ ; SM80-FTZ: { ++-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; +++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; ++ ; SM80-FTZ-NEXT: .reg .b32 %r<4>; ++ ; SM80-FTZ-NEXT: .reg .f32 %f<7>; ++ ; SM80-FTZ-EMPTY: ++@@ -237,10 +239,12 @@ ++ ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; ++ ; SM80-FTZ-NEXT: add.rn.ftz.f32 %f3, %f2, %f1; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; ++ ; SM80-FTZ-NEXT: add.rn.ftz.f32 %f6, %f5, %f4; ++-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-FTZ-NEXT: ret; ++ ; ++@@ -307,7 +311,7 @@ ++ ; ++ ; SM80-LABEL: test_fsubx2( ++ ; SM80: { ++-; SM80-NEXT: .reg .b16 %rs<5>; +++; SM80-NEXT: .reg .b16 %rs<7>; ++ ; SM80-NEXT: .reg .b32 %r<4>; ++ ; SM80-NEXT: .reg .f32 %f<7>; ++ ; SM80-EMPTY: ++@@ -319,16 +323,18 @@ ++ ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; ++ ; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; ++ ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; ++ ; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4; ++-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-NEXT: ret; ++ ; ++ ; SM80-FTZ-LABEL: test_fsubx2( ++ ; SM80-FTZ: { ++-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; +++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; ++ ; SM80-FTZ-NEXT: .reg .b32 %r<4>; ++ ; SM80-FTZ-NEXT: .reg .f32 %f<7>; ++ ; SM80-FTZ-EMPTY: ++@@ -340,10 +346,12 @@ ++ ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; ++ ; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f3, %f2, %f1; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; ++ ; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f6, %f5, %f4; ++-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-FTZ-NEXT: ret; ++ ; ++@@ -410,7 +418,7 @@ ++ ; ++ ; SM80-LABEL: test_fmulx2( ++ ; SM80: { ++-; SM80-NEXT: .reg .b16 %rs<5>; +++; SM80-NEXT: .reg .b16 %rs<7>; ++ ; SM80-NEXT: .reg .b32 %r<4>; ++ ; SM80-NEXT: .reg .f32 %f<7>; ++ ; SM80-EMPTY: ++@@ -422,16 +430,18 @@ ++ ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; ++ ; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; ++ ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; ++ ; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4; ++-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-NEXT: ret; ++ ; ++ ; SM80-FTZ-LABEL: test_fmulx2( ++ ; SM80-FTZ: { ++-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; +++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; ++ ; SM80-FTZ-NEXT: .reg .b32 %r<4>; ++ ; SM80-FTZ-NEXT: .reg .f32 %f<7>; ++ ; SM80-FTZ-EMPTY: ++@@ -443,10 +453,12 @@ ++ ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; ++ ; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f3, %f2, %f1; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; ++ ; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f6, %f5, %f4; ++-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-FTZ-NEXT: ret; ++ ; ++@@ -513,7 +525,7 @@ ++ ; ++ ; SM80-LABEL: test_fdiv( ++ ; SM80: { ++-; SM80-NEXT: .reg .b16 %rs<5>; +++; SM80-NEXT: .reg .b16 %rs<7>; ++ ; SM80-NEXT: .reg .b32 %r<4>; ++ ; SM80-NEXT: .reg .f32 %f<7>; ++ ; SM80-EMPTY: ++@@ -525,16 +537,18 @@ ++ ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; ++ ; SM80-NEXT: div.rn.f32 %f3, %f2, %f1; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; ++ ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; ++ ; SM80-NEXT: div.rn.f32 %f6, %f5, %f4; ++-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-NEXT: ret; ++ ; ++ ; SM80-FTZ-LABEL: test_fdiv( ++ ; SM80-FTZ: { ++-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; +++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; ++ ; SM80-FTZ-NEXT: .reg .b32 %r<4>; ++ ; SM80-FTZ-NEXT: .reg .f32 %f<7>; ++ ; SM80-FTZ-EMPTY: ++@@ -546,16 +560,18 @@ ++ ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; ++ ; SM80-FTZ-NEXT: div.rn.ftz.f32 %f3, %f2, %f1; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; ++ ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; ++ ; SM80-FTZ-NEXT: div.rn.ftz.f32 %f6, %f5, %f4; ++-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM80-FTZ-NEXT: ret; ++ ; ++ ; SM90-LABEL: test_fdiv( ++ ; SM90: { ++-; SM90-NEXT: .reg .b16 %rs<5>; +++; SM90-NEXT: .reg .b16 %rs<7>; ++ ; SM90-NEXT: .reg .b32 %r<4>; ++ ; SM90-NEXT: .reg .f32 %f<7>; ++ ; SM90-EMPTY: ++@@ -567,10 +583,12 @@ ++ ; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r1; ++ ; SM90-NEXT: cvt.f32.bf16 %f2, %rs4; ++ ; SM90-NEXT: div.rn.f32 %f3, %f2, %f1; +++; SM90-NEXT: cvt.rn.bf16.f32 %rs5, %f3; ++ ; SM90-NEXT: cvt.f32.bf16 %f4, %rs1; ++ ; SM90-NEXT: cvt.f32.bf16 %f5, %rs3; ++ ; SM90-NEXT: div.rn.f32 %f6, %f5, %f4; ++-; SM90-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; +++; SM90-NEXT: cvt.rn.bf16.f32 %rs6, %f6; +++; SM90-NEXT: mov.b32 %r3, {%rs6, %rs5}; ++ ; SM90-NEXT: st.param.b32 [func_retval0], %r3; ++ ; SM90-NEXT: ret; ++ %r = fdiv <2 x bfloat> %a, %b ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll ++--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll +++++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll ++@@ -13,7 +13,9 @@ ++ ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; ++ ; CHECK-DAG: sin.approx.f32 [[RF0:%f[0-9]+]], [[AF0]]; ++ ; CHECK-DAG: sin.approx.f32 [[RF1:%f[0-9]+]], [[AF1]]; ++-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]] +++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 { ++@@ -28,7 +30,9 @@ ++ ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; ++ ; CHECK-DAG: cos.approx.f32 [[RF0:%f[0-9]+]], [[AF0]]; ++ ; CHECK-DAG: cos.approx.f32 [[RF1:%f[0-9]+]], [[AF1]]; ++-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]] +++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 { ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll ++--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll +++++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll ++@@ -26,7 +26,9 @@ ++ ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]] ++ ; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000; ++ ; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], 0f40000000; ++-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]] +++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]] +++; SM80-DAG: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; ++ ; CHECK-NEXT: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK-NEXT: ret; ++@@ -66,7 +68,9 @@ ++ ; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; ++ ; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; ++ ; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; ++-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; +++; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; ++ ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++@@ -89,7 +93,9 @@ ++ ; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; ++ ; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; ++ ; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; ++-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; +++; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; ++ ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++@@ -110,7 +116,9 @@ ++ ; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; ++ ; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; ++ ; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; ++-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; +++; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK-NEXT: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK-NEXT: ret; ++ ++@@ -279,7 +287,9 @@ ++ ++ ; CHECK-LABEL: test_fptrunc_2xfloat( ++ ; CHECK: ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0]; ++-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[A0]], [[A1]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[A0]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[A1]]; +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 { ++@@ -349,7 +359,9 @@ ++ ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; ++ ; CHECK-DAG: sqrt.rn.f32 [[RF0:%f[0-9]+]], [[AF0]]; ++ ; CHECK-DAG: sqrt.rn.f32 [[RF1:%f[0-9]+]], [[AF1]]; ++-; CHECK-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; +++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 { ++@@ -424,7 +436,9 @@ ++ ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; ++ ; SM80-DAG: cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]]; ++ ; SM80-DAG: cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]]; ++-; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 { ++@@ -441,7 +455,9 @@ ++ ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; ++ ; SM80-DAG: cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]]; ++ ; SM80-DAG: cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]]; ++-; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; +++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 { ++@@ -454,7 +470,7 @@ ++ ; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; ++ ; SM90: cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]]; ++ ; SM90: cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]]; ++-; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 { ++@@ -467,7 +483,7 @@ ++ ; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; ++ ; SM90: cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]]; ++ ; SM90: cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]]; ++-; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} +++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++ ; CHECK: st.param.b32 [func_retval0], [[R]]; ++ ; CHECK: ret; ++ define <2 x bfloat> @test_rint(<2 x bfloat> %a) #0 { ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll ++--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll +++++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll ++@@ -1,70 +1,41 @@ ++-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 ++ ; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s ++ ; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %} ++ ++ +++; CHECK-LABEL: cvt_rn_bf16x2_f32 ++ define <2 x bfloat> @cvt_rn_bf16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rn_bf16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_bf16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_bf16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2) ++- ret <2 x bfloat> %val +++ +++; CHECK: cvt.rn.bf16x2.f32 +++ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2); +++ +++ret <2 x bfloat> %val ++ } ++ +++; CHECK-LABEL: cvt_rn_relu_bf16x2_f32 ++ define <2 x bfloat> @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rn_relu_bf16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_bf16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_bf16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rn.relu.bf16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2) ++- ret <2 x bfloat> %val +++ +++; CHECK: cvt.rn.relu.bf16x2.f32 +++%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2); +++ +++ret <2 x bfloat> %val ++ } ++ +++; CHECK-LABEL: cvt_rz_bf16x2_f32 ++ define <2 x bfloat> @cvt_rz_bf16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rz_bf16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_bf16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_bf16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rz.bf16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2) ++- ret <2 x bfloat> %val +++ +++; CHECK: cvt.rz.bf16x2.f32 +++ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2); +++ +++ret <2 x bfloat> %val ++ } ++ +++; CHECK-LABEL: cvt_rz_relu_bf16x2_f32 ++ define <2 x bfloat> @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rz_relu_bf16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_bf16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_relu_bf16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rz.relu.bf16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2) ++- ret <2 x bfloat> %val +++ +++; CHECK: cvt.rz.relu.bf16x2.f32 +++%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2); +++ +++ret <2 x bfloat> %val ++ } ++ ++ declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float, float) ++@@ -72,68 +43,40 @@ ++ declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float, float) ++ declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float, float) ++ +++; CHECK-LABEL: cvt_rn_f16x2_f32 ++ define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rn_f16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_f16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_f16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rn.f16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2) ++- ret <2 x half> %val +++ +++; CHECK: cvt.rn.f16x2.f32 +++ %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2); +++ +++ret <2 x half> %val ++ } ++ +++; CHECK-LABEL: cvt_rn_relu_f16x2_f32 ++ define <2 x half> @cvt_rn_relu_f16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rn_relu_f16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_f16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_f16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rn.relu.f16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2) ++- ret <2 x half> %val +++ +++; CHECK: cvt.rn.relu.f16x2.f32 +++%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2); +++ +++ret <2 x half> %val ++ } ++ +++; CHECK-LABEL: cvt_rz_f16x2_f32 ++ define <2 x half> @cvt_rz_f16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rz_f16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_f16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_f16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rz.f16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2) ++- ret <2 x half> %val +++ +++; CHECK: cvt.rz.f16x2.f32 +++ %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2); +++ +++ret <2 x half> %val ++ } ++ +++; CHECK-LABEL: cvt_rz_relu_f16x2_f32 ++ define <2 x half> @cvt_rz_relu_f16x2_f32(float %f1, float %f2) { ++-; CHECK-LABEL: cvt_rz_relu_f16x2_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_f16x2_f32_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_relu_f16x2_f32_param_1]; ++-; CHECK-NEXT: cvt.rz.relu.f16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2) ++- ret <2 x half> %val +++ +++; CHECK: cvt.rz.relu.f16x2.f32 +++%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2); +++ +++ret <2 x half> %val ++ } ++ ++ declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float) ++@@ -141,64 +84,40 @@ ++ declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float) ++ declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float) ++ +++; CHECK-LABEL: cvt_rn_bf16_f32 ++ define bfloat @cvt_rn_bf16_f32(float %f1) { ++-; CHECK-LABEL: cvt_rn_bf16_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b16 %rs<2>; ++-; CHECK-NEXT: .reg .f32 %f<2>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_bf16_f32_param_0]; ++-; CHECK-NEXT: cvt.rn.bf16.f32 %rs1, %f1; ++-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; ++-; CHECK-NEXT: ret; ++- %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1) ++- ret bfloat %val +++ +++; CHECK: cvt.rn.bf16.f32 +++ %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1); +++ +++ret bfloat %val ++ } ++ +++; CHECK-LABEL: cvt_rn_relu_bf16_f32 ++ define bfloat @cvt_rn_relu_bf16_f32(float %f1) { ++-; CHECK-LABEL: cvt_rn_relu_bf16_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b16 %rs<2>; ++-; CHECK-NEXT: .reg .f32 %f<2>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_bf16_f32_param_0]; ++-; CHECK-NEXT: cvt.rn.relu.bf16.f32 %rs1, %f1; ++-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; ++-; CHECK-NEXT: ret; ++- %val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1) ++- ret bfloat %val +++ +++; CHECK: cvt.rn.relu.bf16.f32 +++%val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1); +++ +++ret bfloat %val ++ } ++ +++; CHECK-LABEL: cvt_rz_bf16_f32 ++ define bfloat @cvt_rz_bf16_f32(float %f1) { ++-; CHECK-LABEL: cvt_rz_bf16_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b16 %rs<2>; ++-; CHECK-NEXT: .reg .f32 %f<2>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_bf16_f32_param_0]; ++-; CHECK-NEXT: cvt.rz.bf16.f32 %rs1, %f1; ++-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; ++-; CHECK-NEXT: ret; ++- %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1) ++- ret bfloat %val +++ +++; CHECK: cvt.rz.bf16.f32 +++ %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1); +++ +++ret bfloat %val ++ } ++ +++; CHECK-LABEL: cvt_rz_relu_bf16_f32 ++ define bfloat @cvt_rz_relu_bf16_f32(float %f1) { ++-; CHECK-LABEL: cvt_rz_relu_bf16_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b16 %rs<2>; ++-; CHECK-NEXT: .reg .f32 %f<2>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_bf16_f32_param_0]; ++-; CHECK-NEXT: cvt.rz.relu.bf16.f32 %rs1, %f1; ++-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; ++-; CHECK-NEXT: ret; ++- %val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1) ++- ret bfloat %val +++ +++; CHECK: cvt.rz.relu.bf16.f32 +++%val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1); +++ +++ret bfloat %val ++ } ++ ++ declare bfloat @llvm.nvvm.f2bf16.rn(float) ++@@ -206,58 +125,13 @@ ++ declare bfloat @llvm.nvvm.f2bf16.rz(float) ++ declare bfloat @llvm.nvvm.f2bf16.rz.relu(float) ++ +++; CHECK-LABEL: cvt_rna_tf32_f32 ++ define i32 @cvt_rna_tf32_f32(float %f1) { ++-; CHECK-LABEL: cvt_rna_tf32_f32( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<2>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rna_tf32_f32_param_0]; ++-; CHECK-NEXT: cvt.rna.tf32.f32 %r1, %f1; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1) ++- ret i32 %val ++-} ++ ++-declare i32 @llvm.nvvm.f2tf32.rna(float) +++; CHECK: cvt.rna.tf32.f32 +++ %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1); ++ ++- ++-define <2 x bfloat> @fold_ff2bf16x2(float %a, float %b) { ++-; CHECK-LABEL: fold_ff2bf16x2( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [fold_ff2bf16x2_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [fold_ff2bf16x2_param_1]; ++-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %ah = fptrunc float %a to bfloat ++- %bh = fptrunc float %b to bfloat ++- %v0 = insertelement <2 x bfloat> poison, bfloat %ah, i64 0 ++- %v1 = insertelement <2 x bfloat> %v0, bfloat %bh, i64 1 ++- ret <2 x bfloat> %v1 ++-} ++- ++-define <2 x half> @fold_ff2f16x2(float %a, float %b) { ++-; CHECK-LABEL: fold_ff2f16x2( ++-; CHECK: { ++-; CHECK-NEXT: .reg .b32 %r<2>; ++-; CHECK-NEXT: .reg .f32 %f<3>; ++-; CHECK-EMPTY: ++-; CHECK-NEXT: // %bb.0: ++-; CHECK-NEXT: ld.param.f32 %f1, [fold_ff2f16x2_param_0]; ++-; CHECK-NEXT: ld.param.f32 %f2, [fold_ff2f16x2_param_1]; ++-; CHECK-NEXT: cvt.rn.f16x2.f32 %r1, %f1, %f2; ++-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; ++-; CHECK-NEXT: ret; ++- %ah = fptrunc float %a to half ++- %bh = fptrunc float %b to half ++- %v0 = insertelement <2 x half> poison, half %ah, i64 0 ++- %v1 = insertelement <2 x half> %v0, half %bh, i64 1 ++- ret <2 x half> %v1 +++ret i32 %val ++ } +++ +++declare i32 @llvm.nvvm.f2tf32.rna(float) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index ffa6478..ad42b42 100644 +index ad42b42..63355e1 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "97298853b4de70dbce9c0a140ac38e3ac179e02e" -- LLVM_SHA256 = "ac811cb61d281043c865c39260a5114a0e96d16ec0e4eb74a2516a24981b9064" -+ LLVM_COMMIT = "03730cdd3d10c5270fe436777a37d50b0838a3bf" -+ LLVM_SHA256 = "54d843249c75b200f7bf9b7947079fe16fa0b657c4aee4abdde4ac05a9cd5f84" +- LLVM_COMMIT = "03730cdd3d10c5270fe436777a37d50b0838a3bf" +- LLVM_SHA256 = "54d843249c75b200f7bf9b7947079fe16fa0b657c4aee4abdde4ac05a9cd5f84" ++ LLVM_COMMIT = "b3134fa2338388adf8cfb2d77339d0b042eab9f6" ++ LLVM_SHA256 = "b6024606092290b0838735c26ad1c5c239b3e931136b420af8680e3a1156e759" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 9da8f92769bba..9f016ff4cc057 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "c3f3c8e9ae90470e24891428072ff8cbc7445e95" - SHARDY_SHA256 = "92dd51c7de67e3b0bd09f388f1f71bebb44547a98b33ee9ece66eba0fac439d1" + SHARDY_COMMIT = "d0be1a187e24821862a556448ff91085316a2d80" + SHARDY_SHA256 = "993809c9779ade8d67cd7e9c84cef135d99f17b42fc670f5c557c703b72d18a6" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index 509398da979e8..0113efadd18e9 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1 +1,704 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ++++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +@@ -739,20 +739,6 @@ + def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">; + } + +-def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{ +- return N->hasOneUse(); +-}]>; +- +-def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)), +- (bf16 (fpround_oneuse Float32Regs:$b)))), +- (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>, +- Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>; +- +-def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)), +- (f16 (fpround_oneuse Float32Regs:$b)))), +- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>, +- Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>; +- + //----------------------------------- + // Selection instructions (selp) + //----------------------------------- +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll +--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll ++++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll +@@ -204,7 +204,7 @@ + ; + ; SM80-LABEL: test_faddx2( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -216,16 +216,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: add.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: add.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_faddx2( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -237,10 +239,12 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: add.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: add.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; +@@ -307,7 +311,7 @@ + ; + ; SM80-LABEL: test_fsubx2( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -319,16 +323,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_fsubx2( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -340,10 +346,12 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; +@@ -410,7 +418,7 @@ + ; + ; SM80-LABEL: test_fmulx2( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -422,16 +430,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_fmulx2( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -443,10 +453,12 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; +@@ -513,7 +525,7 @@ + ; + ; SM80-LABEL: test_fdiv( + ; SM80: { +-; SM80-NEXT: .reg .b16 %rs<5>; ++; SM80-NEXT: .reg .b16 %rs<7>; + ; SM80-NEXT: .reg .b32 %r<4>; + ; SM80-NEXT: .reg .f32 %f<7>; + ; SM80-EMPTY: +@@ -525,16 +537,18 @@ + ; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM80-NEXT: div.rn.f32 %f3, %f2, %f1; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM80-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM80-NEXT: div.rn.f32 %f6, %f5, %f4; +-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-NEXT: ret; + ; + ; SM80-FTZ-LABEL: test_fdiv( + ; SM80-FTZ: { +-; SM80-FTZ-NEXT: .reg .b16 %rs<5>; ++; SM80-FTZ-NEXT: .reg .b16 %rs<7>; + ; SM80-FTZ-NEXT: .reg .b32 %r<4>; + ; SM80-FTZ-NEXT: .reg .f32 %f<7>; + ; SM80-FTZ-EMPTY: +@@ -546,16 +560,18 @@ + ; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4; + ; SM80-FTZ-NEXT: div.rn.ftz.f32 %f3, %f2, %f1; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1; + ; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3; + ; SM80-FTZ-NEXT: div.rn.ftz.f32 %f6, %f5, %f4; +-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3; + ; SM80-FTZ-NEXT: ret; + ; + ; SM90-LABEL: test_fdiv( + ; SM90: { +-; SM90-NEXT: .reg .b16 %rs<5>; ++; SM90-NEXT: .reg .b16 %rs<7>; + ; SM90-NEXT: .reg .b32 %r<4>; + ; SM90-NEXT: .reg .f32 %f<7>; + ; SM90-EMPTY: +@@ -567,10 +583,12 @@ + ; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r1; + ; SM90-NEXT: cvt.f32.bf16 %f2, %rs4; + ; SM90-NEXT: div.rn.f32 %f3, %f2, %f1; ++; SM90-NEXT: cvt.rn.bf16.f32 %rs5, %f3; + ; SM90-NEXT: cvt.f32.bf16 %f4, %rs1; + ; SM90-NEXT: cvt.f32.bf16 %f5, %rs3; + ; SM90-NEXT: div.rn.f32 %f6, %f5, %f4; +-; SM90-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3; ++; SM90-NEXT: cvt.rn.bf16.f32 %rs6, %f6; ++; SM90-NEXT: mov.b32 %r3, {%rs6, %rs5}; + ; SM90-NEXT: st.param.b32 [func_retval0], %r3; + ; SM90-NEXT: ret; + %r = fdiv <2 x bfloat> %a, %b +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll +--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll ++++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll +@@ -13,7 +13,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; + ; CHECK-DAG: sin.approx.f32 [[RF0:%f[0-9]+]], [[AF0]]; + ; CHECK-DAG: sin.approx.f32 [[RF1:%f[0-9]+]], [[AF1]]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]] ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 { +@@ -28,7 +30,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; + ; CHECK-DAG: cos.approx.f32 [[RF0:%f[0-9]+]], [[AF0]]; + ; CHECK-DAG: cos.approx.f32 [[RF1:%f[0-9]+]], [[AF1]]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]] ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 { +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll +--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll ++++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll +@@ -26,7 +26,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]] + ; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000; + ; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], 0f40000000; +-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]] ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]] ++; SM80-DAG: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; + ; CHECK-NEXT: st.param.b32 [func_retval0], [[R]]; + ; CHECK-NEXT: ret; +@@ -66,7 +68,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; + ; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; + ; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; ++; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; + + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; +@@ -89,7 +93,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; + ; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; + ; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +-; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; ++; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; + + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; +@@ -110,7 +116,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; + ; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; + ; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; ++; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK-NEXT: st.param.b32 [func_retval0], [[R]]; + ; CHECK-NEXT: ret; + +@@ -279,7 +287,9 @@ + + ; CHECK-LABEL: test_fptrunc_2xfloat( + ; CHECK: ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0]; +-; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[A0]], [[A1]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[A0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[A1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 { +@@ -349,7 +359,9 @@ + ; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]]; + ; CHECK-DAG: sqrt.rn.f32 [[RF0:%f[0-9]+]], [[AF0]]; + ; CHECK-DAG: sqrt.rn.f32 [[RF1:%f[0-9]+]], [[AF1]]; +-; CHECK-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 { +@@ -424,7 +436,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; + ; SM80-DAG: cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]]; + ; SM80-DAG: cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]]; +-; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 { +@@ -441,7 +455,9 @@ + ; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; + ; SM80-DAG: cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]]; + ; SM80-DAG: cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]]; +-; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]]; ++; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]]; ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 { +@@ -454,7 +470,7 @@ + ; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; + ; SM90: cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]]; + ; SM90: cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]]; +-; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 { +@@ -467,7 +483,7 @@ + ; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; + ; SM90: cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]]; + ; SM90: cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]]; +-; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} ++; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} + ; CHECK: st.param.b32 [func_retval0], [[R]]; + ; CHECK: ret; + define <2 x bfloat> @test_rint(<2 x bfloat> %a) #0 { +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll +--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll ++++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll +@@ -1,70 +1,41 @@ +-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 + ; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s + ; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %} + + ++; CHECK-LABEL: cvt_rn_bf16x2_f32 + define <2 x bfloat> @cvt_rn_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rn.bf16x2.f32 ++ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + ++; CHECK-LABEL: cvt_rn_relu_bf16x2_f32 + define <2 x bfloat> @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_relu_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.relu.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rn.relu.bf16x2.f32 ++%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + ++; CHECK-LABEL: cvt_rz_bf16x2_f32 + define <2 x bfloat> @cvt_rz_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rz.bf16x2.f32 ++ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + ++; CHECK-LABEL: cvt_rz_relu_bf16x2_f32 + define <2 x bfloat> @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_relu_bf16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_bf16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_relu_bf16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.relu.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2) +- ret <2 x bfloat> %val ++ ++; CHECK: cvt.rz.relu.bf16x2.f32 ++%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2); ++ ++ret <2 x bfloat> %val + } + + declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float, float) +@@ -72,68 +43,40 @@ + declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float, float) + declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float, float) + ++; CHECK-LABEL: cvt_rn_f16x2_f32 + define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rn.f16x2.f32 ++ %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2); ++ ++ret <2 x half> %val + } + ++; CHECK-LABEL: cvt_rn_relu_f16x2_f32 + define <2 x half> @cvt_rn_relu_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rn_relu_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rn.relu.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rn.relu.f16x2.f32 ++%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2); ++ ++ret <2 x half> %val + } + ++; CHECK-LABEL: cvt_rz_f16x2_f32 + define <2 x half> @cvt_rz_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rz.f16x2.f32 ++ %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2); ++ ++ret <2 x half> %val + } + ++; CHECK-LABEL: cvt_rz_relu_f16x2_f32 + define <2 x half> @cvt_rz_relu_f16x2_f32(float %f1, float %f2) { +-; CHECK-LABEL: cvt_rz_relu_f16x2_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_f16x2_f32_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [cvt_rz_relu_f16x2_f32_param_1]; +-; CHECK-NEXT: cvt.rz.relu.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2) +- ret <2 x half> %val ++ ++; CHECK: cvt.rz.relu.f16x2.f32 ++%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2); ++ ++ret <2 x half> %val + } + + declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float) +@@ -141,64 +84,40 @@ + declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float) + declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float) + ++; CHECK-LABEL: cvt_rn_bf16_f32 + define bfloat @cvt_rn_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rn_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rn.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rn.bf16.f32 ++ %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1); ++ ++ret bfloat %val + } + ++; CHECK-LABEL: cvt_rn_relu_bf16_f32 + define bfloat @cvt_rn_relu_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rn_relu_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rn.relu.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rn.relu.bf16.f32 ++%val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1); ++ ++ret bfloat %val + } + ++; CHECK-LABEL: cvt_rz_bf16_f32 + define bfloat @cvt_rz_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rz_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rz.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rz.bf16.f32 ++ %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1); ++ ++ret bfloat %val + } + ++; CHECK-LABEL: cvt_rz_relu_bf16_f32 + define bfloat @cvt_rz_relu_bf16_f32(float %f1) { +-; CHECK-LABEL: cvt_rz_relu_bf16_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b16 %rs<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_bf16_f32_param_0]; +-; CHECK-NEXT: cvt.rz.relu.bf16.f32 %rs1, %f1; +-; CHECK-NEXT: st.param.b16 [func_retval0], %rs1; +-; CHECK-NEXT: ret; +- %val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1) +- ret bfloat %val ++ ++; CHECK: cvt.rz.relu.bf16.f32 ++%val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1); ++ ++ret bfloat %val + } + + declare bfloat @llvm.nvvm.f2bf16.rn(float) +@@ -206,58 +125,13 @@ + declare bfloat @llvm.nvvm.f2bf16.rz(float) + declare bfloat @llvm.nvvm.f2bf16.rz.relu(float) + ++; CHECK-LABEL: cvt_rna_tf32_f32 + define i32 @cvt_rna_tf32_f32(float %f1) { +-; CHECK-LABEL: cvt_rna_tf32_f32( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<2>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [cvt_rna_tf32_f32_param_0]; +-; CHECK-NEXT: cvt.rna.tf32.f32 %r1, %f1; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1) +- ret i32 %val +-} + +-declare i32 @llvm.nvvm.f2tf32.rna(float) ++; CHECK: cvt.rna.tf32.f32 ++ %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1); + +- +-define <2 x bfloat> @fold_ff2bf16x2(float %a, float %b) { +-; CHECK-LABEL: fold_ff2bf16x2( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [fold_ff2bf16x2_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [fold_ff2bf16x2_param_1]; +-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %ah = fptrunc float %a to bfloat +- %bh = fptrunc float %b to bfloat +- %v0 = insertelement <2 x bfloat> poison, bfloat %ah, i64 0 +- %v1 = insertelement <2 x bfloat> %v0, bfloat %bh, i64 1 +- ret <2 x bfloat> %v1 +-} +- +-define <2 x half> @fold_ff2f16x2(float %a, float %b) { +-; CHECK-LABEL: fold_ff2f16x2( +-; CHECK: { +-; CHECK-NEXT: .reg .b32 %r<2>; +-; CHECK-NEXT: .reg .f32 %f<3>; +-; CHECK-EMPTY: +-; CHECK-NEXT: // %bb.0: +-; CHECK-NEXT: ld.param.f32 %f1, [fold_ff2f16x2_param_0]; +-; CHECK-NEXT: ld.param.f32 %f2, [fold_ff2f16x2_param_1]; +-; CHECK-NEXT: cvt.rn.f16x2.f32 %r1, %f1, %f2; +-; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +-; CHECK-NEXT: ret; +- %ah = fptrunc float %a to half +- %bh = fptrunc float %b to half +- %v0 = insertelement <2 x half> poison, half %ah, i64 0 +- %v1 = insertelement <2 x half> %v0, half %bh, i64 1 +- ret <2 x half> %v1 ++ret i32 %val + } ++ ++declare i32 @llvm.nvvm.f2tf32.rna(float) diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index ad42b42a95f79..63355e174793a 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "03730cdd3d10c5270fe436777a37d50b0838a3bf" - LLVM_SHA256 = "54d843249c75b200f7bf9b7947079fe16fa0b657c4aee4abdde4ac05a9cd5f84" + LLVM_COMMIT = "b3134fa2338388adf8cfb2d77339d0b042eab9f6" + LLVM_SHA256 = "b6024606092290b0838735c26ad1c5c239b3e931136b420af8680e3a1156e759" tf_http_archive( name = name,