diff --git a/gcc/config/rs6000/float16.cc b/gcc/config/rs6000/float16.cc
index fa196486a635..0d606609dab3 100644
--- a/gcc/config/rs6000/float16.cc
+++ b/gcc/config/rs6000/float16.cc
@@ -60,13 +60,13 @@
    SFmode.  */
 
 void
-bfloat16_expand_binary_op (enum rtx_code icode,
-                           rtx op0,
-                           rtx op1,
-                           rtx op2,
-                           rtx tmp0,
-                           rtx tmp1,
-                           rtx tmp2)
+bfloat16_binary_op_as_v4sf (enum rtx_code icode,
+                            rtx op0,
+                            rtx op1,
+                            rtx op2,
+                            rtx tmp0,
+                            rtx tmp1,
+                            rtx tmp2)
 {
   if (GET_CODE (tmp0) == SCRATCH)
     tmp0 = gen_reg_rtx (V4SFmode);
diff --git a/gcc/config/rs6000/float16.md b/gcc/config/rs6000/float16.md
index db757888dbf5..2bc552d344b3 100644
--- a/gcc/config/rs6000/float16.md
+++ b/gcc/config/rs6000/float16.md
@@ -40,19 +40,6 @@
 ;; convert to/from _Float16 (HFmode) via DFmode.
 (define_mode_iterator fp16_float_convert [TF KF IF SD DD TD])
 
-;; Code iterator giving the basic operations for bfloat16 floating point
-;; operations.
-(define_code_iterator fp16_binary_op [plus div minus mult smax smin])
-
-;; Code attribute that gives the standard name for the bfloat16
-;; operations done via V4SF vector.
-(define_code_attr fp16_binary_name [(plus "add")
-                                    (div "div")
-                                    (minus "sub")
-                                    (mult "mul")
-                                    (smax "smax")
-                                    (smin "smin")])
-
 ;; Mode attribute giving the instruction to convert the even
 ;; V8HFmode or V8BFmode elements to V4SFmode
 (define_mode_attr cvt_fp16_to_v4sf_insn [(BF "xvcvbf16spn")
@@ -398,147 +385,198 @@
    (set_attr "type" "vecperm")])
 
-;; Bfloat16 floating point operations.  We convert the 16-bit scalar to a
-;; V4SF vector, do the operation, and then convert the value back to
-;; 16-bit format.  We only care about the 2nd element that the scalar
-;; value in it.  For plus, minus, and mult the other 3 elements can be
-;; 0.  This means we can combine a load (which sets the other bits to
-;; 0) with the conversion to vector.  For divide, the divisor must not
-;; be 0, so we use a splat operation to guarantee that we are not
-;; dividing by 0.
+;; Optimize __bfloat16 binary operations.  Unlike _Float16 where we
+;; have instructions to convert between HFmode and SFmode as scalar
+;; values, with BFmode, we only have vector conversions.  Thus to do:
+;;
+;;      __bfloat16 a, b, c;
+;;      a = b + c;
+;;
+;; the GCC compiler would normally generate:
+;;
+;;      lxsihzx 0,4,2           // load __bfloat16 value b
+;;      lxsihzx 12,5,2          // load __bfloat16 value c
+;;      xxsldwi 0,0,0,1         // shift b into bits 16..31
+;;      xxsldwi 12,12,12,1      // shift c into bits 16..31
+;;      xvcvbf16spn 0,0         // vector convert b into V4SFmode
+;;      xvcvbf16spn 12,12       // vector convert c into V4SFmode
+;;      xscvspdpn 0,0           // convert b into SFmode scalar
+;;      xscvspdpn 12,12         // convert c into SFmode scalar
+;;      fadds 0,0,12            // add b+c
+;;      xscvdpspn 0,0           // convert b+c into SFmode memory format
+;;      xvcvspbf16 0,0          // convert b+c into BFmode memory format
+;;      stxsihx 0,3,2           // store b+c
+;;
+;; Using the following combiner patterns, the code generated would now
+;; be:
+;;
+;;      lxsihzx 12,4,2          // load __bfloat16 value b
+;;      lxsihzx 0,5,2           // load __bfloat16 value c
+;;      xxspltw 12,12,1         // shift b into bits 16..31
+;;      xxspltw 0,0,1           // shift c into bits 16..31
+;;      xvcvbf16spn 12,12       // vector convert b into V4SFmode
+;;      xvcvbf16spn 0,0         // vector convert c into V4SFmode
+;;      xvaddsp 0,0,12          // vector b+c in V4SFmode
+;;      xvcvspbf16 0,0          // convert b+c into BFmode memory format
+;;      stxsihx 0,3,2           // store b+c
+;;
+;; We cannot just define insns like 'addbf3' to keep the operation as
+;; BFmode because GCC will not generate these patterns unless the user
+;; uses -Ofast.  Without -Ofast, it will always convert BFmode into
+;; SFmode.
 
-(define_insn_and_split "<fp16_binary_name>bf3"
-  [(set (match_operand:BF 0 "vsx_register_operand" "=wa,wa,wa")
-        (fp16_binary_op:BF
-         (match_operand:BF 1 "vsx_register_operand" "wa,wa,wa")
-         (match_operand:BF 2 "fp16_reg_or_constant_operand" "wa,j,eP")))
-   (clobber (match_scratch:V4SF 3 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))]
-  "TARGET_BFLOAT16_HW"
-  "#"
-  "&& 1"
-  [(pc)]
-{
-  bfloat16_expand_binary_op (<CODE>,
-                             operands[0],
-                             operands[1],
-                             operands[2],
-                             operands[3],
-                             operands[4],
-                             operands[5]);
-  DONE;
-}
-  [(set_attr "type" "vecperm")
-   (set_attr "length" "24,24,32")])
-
-(define_insn_and_split "*<fp16_binary_name>bf3_internal1"
+(define_insn_and_split "*bfloat16_binary_op_internal1"
   [(set (match_operand:SF 0 "vsx_register_operand" "=wa")
-        (fp16_binary_op:SF
-         (float_extend:SF
-          (match_operand:BF 1 "vsx_register_operand" "wa"))
-         (float_extend:SF
-          (match_operand:BF 2 "vsx_register_operand" "wa"))))
-   (clobber (match_scratch:V4SF 3 "=&wa"))
+        (match_operator:SF 1 "bfloat16_binary_operator"
+         [(float_extend:SF
+           (match_operand:BF 2 "vsx_register_operand" "wa"))
+          (float_extend:SF
+           (match_operand:BF 3 "vsx_register_operand" "wa"))]))
   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))]
+   (clobber (match_scratch:V4SF 5 "=&wa"))
+   (clobber (match_scratch:V4SF 6 "=&wa"))]
   "TARGET_BFLOAT16_HW"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_expand_binary_op (<CODE>,
-                             operands[0],
-                             operands[1],
-                             operands[2],
-                             operands[3],
-                             operands[4],
-                             operands[5]);
+  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+                              operands[0],
+                              operands[2],
+                              operands[3],
+                              operands[4],
+                              operands[5],
+                              operands[6]);
   DONE;
-}
-  [(set_attr "type" "vecperm")
-   (set_attr "length" "24")])
+})
 
-(define_insn_and_split "*<fp16_binary_name>bf3_internal2"
+(define_insn_and_split "*bfloat16_binary_op_internal2"
   [(set (match_operand:BF 0 "vsx_register_operand" "=wa")
         (float_truncate:BF
-         (fp16_binary_op:SF
-          (float_extend:SF
-           (match_operand:BF 1 "vsx_register_operand" "wa"))
-          (float_extend:SF
-           (match_operand:BF 2 "vsx_register_operand" "wa")))))
-   (clobber (match_scratch:V4SF 3 "=&wa"))
+         (match_operator:SF 1 "bfloat16_binary_operator"
+          [(float_extend:SF
+            (match_operand:BF 2 "vsx_register_operand" "wa"))
+           (float_extend:SF
+            (match_operand:BF 3 "vsx_register_operand" "wa"))])))
   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))]
+   (clobber (match_scratch:V4SF 5 "=&wa"))
+   (clobber (match_scratch:V4SF 6 "=&wa"))]
   "TARGET_BFLOAT16_HW"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_expand_binary_op (<CODE>,
-                             operands[0],
-                             operands[1],
-                             operands[2],
-                             operands[3],
-                             operands[4],
-                             operands[5]);
+  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+                              operands[0],
+                              operands[2],
+                              operands[3],
+                              operands[4],
+                              operands[5],
+                              operands[6]);
   DONE;
-}
-  [(set_attr "type" "vecperm")
-   (set_attr "length" "24")])
+})
 
-(define_insn_and_split "*<fp16_binary_name>bf3_internal3"
+(define_insn_and_split "*bfloat16_binary_op_internal3"
   [(set (match_operand:SF 0 "vsx_register_operand" "=wa,wa,wa")
-        (fp16_binary_op:SF
-         (float_extend:SF
-          (match_operand:BF 1 "vsx_register_operand" "wa,wa,wa"))
-         (match_operand:SF 2 "input_operand" "wa,j,eP")))
-   (clobber (match_scratch:V4SF 3 "=&wa,&wa,&wa"))
+        (match_operator:SF 1 "bfloat16_binary_operator"
+         [(float_extend:SF
+           (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
+          (match_operand:SF 3 "input_operand" "wa,j,eP")]))
   (clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))]
+   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
+   (clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
   "TARGET_BFLOAT16_HW"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_expand_binary_op (<CODE>,
-                             operands[0],
-                             operands[1],
-                             operands[2],
-                             operands[3],
-                             operands[4],
-                             operands[5]);
+  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+                              operands[0],
+                              operands[2],
+                              operands[3],
+                              operands[4],
+                              operands[5],
+                              operands[6]);
   DONE;
 }
-  [(set_attr "type" "vecperm")
-   (set_attr "length" "24,24,32")])
+  [(set_attr "type" "vecperm")])
 
-(define_insn_and_split "*<fp16_binary_name>bf3_internal4"
-  [(set (match_operand:BF 0 "vsx_register_operand" "=wa,wa,wa")
+(define_insn_and_split "*bfloat16_binary_op_internal4"
+  [(set (match_operand:BF 0 "vsx_register_operand" "=wa,&wa,&wa")
         (float_truncate:BF
-         (fp16_binary_op:SF
-          (float_extend:SF
-           (match_operand:BF 1 "vsx_register_operand" "wa,wa,wa"))
-          (match_operand:SF 2 "input_operand" "wa,j,eP"))))
-   (clobber (match_scratch:V4SF 3 "=&wa,&wa,&wa"))
+         (match_operator:SF 1 "bfloat16_binary_operator"
+          [(float_extend:SF
+            (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
+           (match_operand:SF 3 "input_operand" "wa,j,eP")])))
   (clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))]
+   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
+   (clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
  "TARGET_BFLOAT16_HW"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_expand_binary_op (<CODE>,
-                             operands[0],
-                             operands[1],
-                             operands[2],
-                             operands[3],
-                             operands[4],
-                             operands[5]);
+  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+                              operands[0],
+                              operands[2],
+                              operands[3],
+                              operands[4],
+                              operands[5],
+                              operands[6]);
   DONE;
 }
-  [(set_attr "type" "vecperm")
-   (set_attr "length" "24,24,32")])
+  [(set_attr "type" "vecperm")])
+
+(define_insn_and_split "*bfloat16_binary_op_internal5"
+  [(set (match_operand:SF 0 "vsx_register_operand" "=wa")
+        (match_operator:SF 1 "bfloat16_binary_operator"
+         [(match_operand:SF 2 "vsx_register_operand" "wa")
+          (float_extend:SF
+           (match_operand:BF 3 "vsx_register_operand" "wa"))]))
+   (clobber (match_scratch:V4SF 4 "=&wa"))
+   (clobber (match_scratch:V4SF 5 "=&wa"))
+   (clobber (match_scratch:V4SF 6 "=&wa"))]
+  "TARGET_BFLOAT16_HW"
+  "#"
+  "&& 1"
+  [(pc)]
+{
+  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+                              operands[0],
+                              operands[2],
+                              operands[3],
+                              operands[4],
+                              operands[5],
+                              operands[6]);
+  DONE;
+}
+  [(set_attr "type" "vecperm")])
+
+(define_insn_and_split "*bfloat16_binary_op_internal6"
+  [(set (match_operand:BF 0 "vsx_register_operand" "=wa")
+        (float_truncate:BF
+         (match_operator:SF 1 "bfloat16_binary_operator"
+          [(match_operand:SF 3 "vsx_register_operand" "wa")
+           (float_extend:SF
+            (match_operand:BF 2 "vsx_register_operand" "wa"))])))
+   (clobber (match_scratch:V4SF 4 "=&wa"))
+   (clobber (match_scratch:V4SF 5 "=&wa"))
+   (clobber (match_scratch:V4SF 6 "=&wa"))]
+  "TARGET_BFLOAT16_HW"
+  "#"
+  "&& 1"
+  [(pc)]
+{
+  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+                              operands[0],
+                              operands[2],
+                              operands[3],
+                              operands[4],
+                              operands[5],
+                              operands[6]);
+  DONE;
+}
+  [(set_attr "type" "vecperm")])
 
 ;; Duplicate a HF/BF value so it can be used for xvcvhpspn/xvcvbf16spn.
 ;; Because xvcvhpspn/xvcvbf16spn only uses the even elements, we can
diff --git a/gcc/config/rs6000/predicates.md b/gcc/config/rs6000/predicates.md
index 4f7d1dd0be5a..2de33f7f32a6 100644
--- a/gcc/config/rs6000/predicates.md
+++ b/gcc/config/rs6000/predicates.md
@@ -2202,3 +2202,11 @@
 
   return false;
 })
+
+;; Match binary operators where we convert a BFmode operand into a
+;; SFmode operand so that we can optimize the BFmode operation to do
+;; the operation in vector mode rather than converting the BFmode to a
+;; V8BFmode vector, converting that V8BFmode vector to V4SFmode, and
+;; then converting the V4SFmode element to SFmode scalar.
+(define_predicate "bfloat16_binary_operator"
+  (match_code "plus,minus,mult,smax,smin"))
diff --git a/gcc/config/rs6000/rs6000-protos.h b/gcc/config/rs6000/rs6000-protos.h
index 23b7f29cbece..063f74f6e3f6 100644
--- a/gcc/config/rs6000/rs6000-protos.h
+++ b/gcc/config/rs6000/rs6000-protos.h
@@ -260,8 +260,8 @@ extern unsigned constant_generates_xxspltiw (vec_const_128bit_type *);
 extern unsigned constant_generates_xxspltidp (vec_const_128bit_type *);
 
 /* From float16.cc.  */
-extern void bfloat16_expand_binary_op (enum rtx_code, rtx, rtx, rtx,
-                                       rtx, rtx, rtx);
+extern void bfloat16_binary_op_as_v4sf (enum rtx_code, rtx, rtx, rtx,
+                                        rtx, rtx, rtx);
 #endif /* RTX_CODE */
 
 #ifdef TREE_CODE
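
Illustrative example (not part of the patch): a minimal C test of the kind
described in the new float16.md comment.  It assumes the __bfloat16 type and
bfloat16 hardware support (TARGET_BFLOAT16_HW) are available on the target;
the function name bf16_add is made up for illustration.

    /* With the combiner patterns above, the float_extend/float_truncate
       form of this addition should be recognized (e.g. by
       *bfloat16_binary_op_internal2) and stay as a V4SFmode vector
       operation (xvcvbf16spn / xvaddsp / xvcvspbf16) instead of being
       converted to and from an SFmode scalar.  */
    __bfloat16
    bf16_add (__bfloat16 b, __bfloat16 c)
    {
      return b + c;
    }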