Do bfloat16 binary operations as combiner patterns.

2025-10-07  Michael Meissner  <meissner@linux.ibm.com>

gcc/

	* config/rs6000/float16.cc (bfloat16_binary_op_as_v4sf): Rename from
	bfloat16_expand_binary_op.
	* config/rs6000/float16.md (fp16_binary_op): Delete.
	(fp16_binary_name): Likewise.
	(<fp16_binary_name>bf): Likewise.
	(bfloat16_binary_op_internal1): Update and extend bfloat16 combiner
	patterns.
	(bfloat16_binary_op_internal2): Likewise.
	(bfloat16_binary_op_internal3): Likewise.
	(bfloat16_binary_op_internal4): Likewise.
	(bfloat16_binary_op_internal5): Likewise.
	(bfloat16_binary_op_internal6): Likewise.
	* config/rs6000/predicates.md (bfloat16_binary_operator): New predicate.
	* config/rs6000/rs6000-protos.h (bfloat16_binary_op_as_v4sf): Rename
	from bfloat16_expand_binary_op.
This commit is contained in:
Michael Meissner 2025-10-07 12:43:47 -04:00
parent 61d6ec5e8b
commit 04a8c5de44
4 changed files with 170 additions and 124 deletions

View File

@ -60,13 +60,13 @@
SFmode. */
void
bfloat16_expand_binary_op (enum rtx_code icode,
rtx op0,
rtx op1,
rtx op2,
rtx tmp0,
rtx tmp1,
rtx tmp2)
bfloat16_binary_op_as_v4sf (enum rtx_code icode,
rtx op0,
rtx op1,
rtx op2,
rtx tmp0,
rtx tmp1,
rtx tmp2)
{
if (GET_CODE (tmp0) == SCRATCH)
tmp0 = gen_reg_rtx (V4SFmode);

View File

@ -40,19 +40,6 @@
;; convert to/from _Float16 (HFmode) via DFmode.
(define_mode_iterator fp16_float_convert [TF KF IF SD DD TD])
;; Code iterator giving the basic operations for bfloat16 floating point
;; operations.
(define_code_iterator fp16_binary_op [plus div minus mult smax smin])
;; Code attribute that gives the standard name for the bfloat16
;; operations done via V4SF vector.
(define_code_attr fp16_binary_name [(plus "add")
(div "div")
(minus "sub")
(mult "mul")
(smax "smax")
(smin "smin")])
;; Mode attribute giving the instruction to convert the even
;; V8HFmode or V8BFmode elements to V4SFmode
(define_mode_attr cvt_fp16_to_v4sf_insn [(BF "xvcvbf16spn")
@ -398,147 +385,198 @@
(set_attr "type" "vecperm")])
;; Bfloat16 floating point operations. We convert the 16-bit scalar to a
;; V4SF vector, do the operation, and then convert the value back to
;; 16-bit format. We only care about the 2nd element that the scalar
;; value in it. For plus, minus, and mult the other 3 elements can be
;; 0. This means we can combine a load (which sets the other bits to
;; 0) with the conversion to vector. For divide, the divisor must not
;; be 0, so we use a splat operation to guarantee that we are not
;; dividing by 0.
;; Optimize __bfloat16 binary operations. Unlike _Float16 where we
;; have instructions to convert between HFmode and SFmode as scalar
;; values, with BFmode, we only have vector conversions. Thus to do:
;;
;; __bfloat16 a, b, c;
;; a = b + c;
;;
;; the GCC compiler would normally generate:
;;
;; lxsihzx 0,4,2 // load __bfloat16 value b
;; lxsihzx 12,5,2 // load __bfloat16 value c
;; xxsldwi 0,0,0,1 // shift b into bits 16..31
;; xxsldwi 12,12,12,1 // shift c into bits 16..31
;; xvcvbf16spn 0,0 // vector convert b into V4SFmode
;; xvcvbf16spn 12,12 // vector convert c into V4SFmode
;; xscvspdpn 0,0 // convert b into SFmode scalar
;; xscvspdpn 12,12 // convert c into SFmode scalar
;; fadds 0,0,12 // add b+c
;; xscvdpspn 0,0 // convert b+c into SFmode memory format
;; xvcvspbf16 0,0 // convert b+c into BFmode memory format
;; stxsihx 0,3,2 // store b+c
;;
;; Using the following combiner patterns, the code generated would now
;; be:
;;
;; lxsihzx 12,4,2 // load __bfloat16 value b
;; lxsihzx 0,5,2 // load __bfloat16 value c
;; xxspltw 12,12,1 // shift b into bits 16..31
;; xxspltw 0,0,1 // shift c into bits 16..31
;; xvcvbf16spn 12,12 // vector convert b into V4SFmode
;; xvcvbf16spn 0,0 // vector convert c into V4SFmode
;; xvaddsp 0,0,12 // vector b+c in V4SFmode
;; xvcvspbf16 0,0 // convert b+c into BFmode memory format
;; stxsihx 0,3,2 // store b+c
;;
;; We cannot just define insns like 'addbf3' to keep the operation as
;; BFmode because GCC will not generate these patterns unless the user
;; uses -Ofast. Without -Ofast, it will always convert BFmode into
;; SFmode.
(define_insn_and_split "<fp16_binary_name>bf3"
[(set (match_operand:BF 0 "vsx_register_operand" "=wa,wa,wa")
(fp16_binary_op:BF
(match_operand:BF 1 "vsx_register_operand" "wa,wa,wa")
(match_operand:BF 2 "fp16_reg_or_constant_operand" "wa,j,eP")))
(clobber (match_scratch:V4SF 3 "=&wa,&wa,&wa"))
(clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
(clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))]
"TARGET_BFLOAT16_HW"
"#"
"&& 1"
[(pc)]
{
bfloat16_expand_binary_op (<CODE>,
operands[0],
operands[1],
operands[2],
operands[3],
operands[4],
operands[5]);
DONE;
}
[(set_attr "type" "vecperm")
(set_attr "length" "24,24,32")])
(define_insn_and_split "*<fp16_binary_name>bf3_internal1"
(define_insn_and_split "*bfloat16_binary_op_internal1"
[(set (match_operand:SF 0 "vsx_register_operand" "=wa")
(fp16_binary_op:SF
(float_extend:SF
(match_operand:BF 1 "vsx_register_operand" "wa"))
(float_extend:SF
(match_operand:BF 2 "vsx_register_operand" "wa"))))
(clobber (match_scratch:V4SF 3 "=&wa"))
(match_operator:SF 1 "bfloat16_binary_operator"
[(float_extend:SF
(match_operand:BF 2 "vsx_register_operand" "wa"))
(float_extend:SF
(match_operand:BF 3 "vsx_register_operand" "wa"))]))
(clobber (match_scratch:V4SF 4 "=&wa"))
(clobber (match_scratch:V4SF 5 "=&wa"))]
(clobber (match_scratch:V4SF 5 "=&wa"))
(clobber (match_scratch:V4SF 6 "=&wa"))]
"TARGET_BFLOAT16_HW"
"#"
"&& 1"
[(pc)]
{
bfloat16_expand_binary_op (<CODE>,
operands[0],
operands[1],
operands[2],
operands[3],
operands[4],
operands[5]);
bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
operands[0],
operands[2],
operands[3],
operands[4],
operands[5],
operands[6]);
DONE;
}
[(set_attr "type" "vecperm")
(set_attr "length" "24")])
})
(define_insn_and_split "*<fp16_binary_name>bf3_internal2"
(define_insn_and_split "*bfloat16_binary_op_internal2"
[(set (match_operand:BF 0 "vsx_register_operand" "=wa")
(float_truncate:BF
(fp16_binary_op:SF
(float_extend:SF
(match_operand:BF 1 "vsx_register_operand" "wa"))
(float_extend:SF
(match_operand:BF 2 "vsx_register_operand" "wa")))))
(clobber (match_scratch:V4SF 3 "=&wa"))
(match_operator:SF 1 "bfloat16_binary_operator"
[(float_extend:SF
(match_operand:BF 2 "vsx_register_operand" "wa"))
(float_extend:SF
(match_operand:BF 3 "vsx_register_operand" "wa"))])))
(clobber (match_scratch:V4SF 4 "=&wa"))
(clobber (match_scratch:V4SF 5 "=&wa"))]
(clobber (match_scratch:V4SF 5 "=&wa"))
(clobber (match_scratch:V4SF 6 "=&wa"))]
"TARGET_BFLOAT16_HW"
"#"
"&& 1"
[(pc)]
{
bfloat16_expand_binary_op (<CODE>,
operands[0],
operands[1],
operands[2],
operands[3],
operands[4],
operands[5]);
bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
operands[0],
operands[2],
operands[3],
operands[4],
operands[5],
operands[6]);
DONE;
}
[(set_attr "type" "vecperm")
(set_attr "length" "24")])
})
(define_insn_and_split "*<fp16_binary_name>bf3_internal3"
(define_insn_and_split "*bfloat16_binary_op_internal3"
[(set (match_operand:SF 0 "vsx_register_operand" "=wa,wa,wa")
(fp16_binary_op:SF
(float_extend:SF
(match_operand:BF 1 "vsx_register_operand" "wa,wa,wa"))
(match_operand:SF 2 "input_operand" "wa,j,eP")))
(clobber (match_scratch:V4SF 3 "=&wa,&wa,&wa"))
(match_operator:SF 1 "bfloat16_binary_operator"
[(float_extend:SF
(match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
(match_operand:SF 3 "input_operand" "wa,j,eP")]))
(clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
(clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))]
(clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
(clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
"TARGET_BFLOAT16_HW"
"#"
"&& 1"
[(pc)]
{
bfloat16_expand_binary_op (<CODE>,
operands[0],
operands[1],
operands[2],
operands[3],
operands[4],
operands[5]);
bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
operands[0],
operands[2],
operands[3],
operands[4],
operands[5],
operands[6]);
DONE;
}
[(set_attr "type" "vecperm")
(set_attr "length" "24,24,32")])
[(set_attr "type" "vecperm")])
(define_insn_and_split "*<fp16_binary_name>bf3_internal4"
[(set (match_operand:BF 0 "vsx_register_operand" "=wa,wa,wa")
(define_insn_and_split "*bfloat16_binary_op_internal4"
[(set (match_operand:BF 0 "vsx_register_operand" "=wa,&wa,&wa")
(float_truncate:BF
(fp16_binary_op:SF
(float_extend:SF
(match_operand:BF 1 "vsx_register_operand" "wa,wa,wa"))
(match_operand:SF 2 "input_operand" "wa,j,eP"))))
(clobber (match_scratch:V4SF 3 "=&wa,&wa,&wa"))
(match_operator:SF 1 "bfloat16_binary_operator"
[(float_extend:SF
(match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
(match_operand:SF 3 "input_operand" "wa,j,eP")])))
(clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
(clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))]
(clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
(clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
"TARGET_BFLOAT16_HW"
"#"
"&& 1"
[(pc)]
{
bfloat16_expand_binary_op (<CODE>,
operands[0],
operands[1],
operands[2],
operands[3],
operands[4],
operands[5]);
bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
operands[0],
operands[2],
operands[3],
operands[4],
operands[5],
operands[6]);
DONE;
}
[(set_attr "type" "vecperm")
(set_attr "length" "24,24,32")])
[(set_attr "type" "vecperm")])
(define_insn_and_split "*bfloat16_binary_op_internal5"
[(set (match_operand:SF 0 "vsx_register_operand" "=wa")
(match_operator:SF 1 "bfloat16_binary_operator"
[(match_operand:SF 2 "vsx_register_operand" "wa")
(float_extend:SF
(match_operand:BF 3 "vsx_register_operand" "wa"))]))
(clobber (match_scratch:V4SF 4 "=&wa"))
(clobber (match_scratch:V4SF 5 "=&wa"))
(clobber (match_scratch:V4SF 6 "=&wa"))]
"TARGET_BFLOAT16_HW"
"#"
"&& 1"
[(pc)]
{
bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
operands[0],
operands[2],
operands[3],
operands[4],
operands[5],
operands[6]);
DONE;
}
[(set_attr "type" "vecperm")])
(define_insn_and_split "*bfloat16_binary_op_internal6"
[(set (match_operand:BF 0 "vsx_register_operand" "=wa")
(float_truncate:BF
(match_operator:SF 1 "bfloat16_binary_operator"
[(match_operand:SF 3 "vsx_register_operand" "wa")
(float_extend:SF
(match_operand:BF 2 "vsx_register_operand" "wa"))])))
(clobber (match_scratch:V4SF 4 "=&wa"))
(clobber (match_scratch:V4SF 5 "=&wa"))
(clobber (match_scratch:V4SF 6 "=&wa"))]
"TARGET_BFLOAT16_HW"
"#"
"&& 1"
[(pc)]
{
bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
operands[0],
operands[2],
operands[3],
operands[4],
operands[5],
operands[6]);
DONE;
}
[(set_attr "type" "vecperm")])
;; Duplicate a HF/BF value so it can be used for xvcvhpspn/xvcvbf16spn.
;; Because xvcvhpspn/xvcvbf16spn only uses the even elements, we can

View File

@ -2202,3 +2202,11 @@
return false;
})
;; Match binary operators where we convert a BFmode operand into a
;; SFmode operand so that we can optimize the BFmode operation to do
;; the operation in vector mode rather than convverting the BFmode to a
;; V8BFmode vector, converting that V8BFmode vector to V4SFmode, and
;; then converting the V4SFmode element to SFmode scalar.
(define_predicate "bfloat16_binary_operator"
(match_code "plus,minus,mult,smax,smin"))

View File

@ -260,8 +260,8 @@ extern unsigned constant_generates_xxspltiw (vec_const_128bit_type *);
extern unsigned constant_generates_xxspltidp (vec_const_128bit_type *);
/* From float16.cc. */
extern void bfloat16_expand_binary_op (enum rtx_code, rtx, rtx, rtx,
rtx, rtx, rtx);
extern void bfloat16_binary_op_as_v4sf (enum rtx_code, rtx, rtx, rtx,
rtx, rtx, rtx);
#endif /* RTX_CODE */
#ifdef TREE_CODE