;; Rewrites for shifts and rotates: `ishl`, `ushr`, `sshr`, `rotl`, `rotr`.
;; x >> 0 == x << 0 == x rotr 0 == x rotl 0 == x.
(rule (simplify (ishl ty
x
(iconst_u ty 0)))
(subsume x))
(rule (simplify (ushr ty
x
(iconst_u ty 0)))
(subsume x))
(rule (simplify (sshr ty
x
(iconst_u ty 0)))
(subsume x))
(rule (simplify (rotr ty
x
(iconst_u ty 0)))
(subsume x))
(rule (simplify (rotl ty
x
(iconst_u ty 0)))
(subsume x))
;; `(x >> k) << k` is the same as masking off the bottom `k` bits (regardless
;; of whether the shift right is signed or unsigned).
(rule (simplify (ishl (fits_in_64 ty)
(ushr ty x (iconst _ k))
(iconst _ k)))
(let ((mask Imm64 (imm64_shl ty (imm64 0xFFFF_FFFF_FFFF_FFFF) k)))
(band ty x (iconst ty mask))))
(rule (simplify (ishl (fits_in_64 ty)
(sshr ty x (iconst _ k))
(iconst _ k)))
(let ((mask Imm64 (imm64_shl ty (imm64 0xFFFF_FFFF_FFFF_FFFF) k)))
(band ty x (iconst ty mask))))
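;; For example, with `ty = i32` and `k = 8` (and assuming `imm64_shl` masks
;; its result to the type's width), the mask above is 0xFFFF_FF00:
;;
;;   ((x >>u 8) << 8) == (band x 0xFFFF_FF00)
;;
;; i.e. the bottom 8 bits are cleared.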
;; For unsigned shifts, `(x << k) >> k` is the same as masking out the top
;; `k` bits. A similar rule is valid for vectors, but this `iconst` mask
;; only works for scalar integers.
(rule (simplify (ushr (fits_in_64 (ty_int ty))
(ishl ty x (iconst _ k))
(iconst _ k)))
(band ty x (iconst ty (imm64_ushr ty (imm64 (ty_mask ty)) k))))
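;; E.g. with `ty = i32` and `k = 24`: `ty_mask` is 0xFFFF_FFFF and
;; 0xFFFF_FFFF >> 24 == 0xFF, so `((x << 24) >>u 24)` becomes
;; `(band x 0xFF)`, keeping only the low 8 bits.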
;; For signed shifts, `(x << k) >> k` does sign-extension from `n` bits to
;; `n+k` bits. In the special case where `x` is the result of either `sextend`
;; or `uextend` from `n` bits to `n+k` bits, we can implement this using
;; `sextend`.
(rule (simplify (sshr wide
(ishl wide
(uextend wide x @ (value_type narrow))
(iconst_u _ shift_u64))
(iconst_u _ shift_u64)))
(if-let $true (u64_eq shift_u64 (u64_sub (ty_bits_u64 wide) (ty_bits_u64 narrow))))
(sextend wide x))
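;; Concretely, with `wide = i64` and `narrow = i8` the check requires
;; `shift_u64 = 56` (= 64 - 8), and
;;
;;   (sshr (ishl (uextend x) 56) 56) == (sextend i64 x)
;;
;; since the shifts move the low 8 bits of `x` to the top and then
;; arithmetically shift them back, replicating bit 7 of `x`.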
;; If `k` is smaller than the difference in bit widths of the two types, then
;; the intermediate sign bit comes from the extend op, so the final result is
;; the same as the original extend op.
(rule (simplify (sshr wide
(ishl wide
x @ (uextend wide (value_type narrow))
(iconst_u _ shift_u64))
(iconst_u _ shift_u64)))
(if-let $true (u64_lt shift_u64 (u64_sub (ty_bits_u64 wide) (ty_bits_u64 narrow))))
x)
;; If the original extend op was `sextend`, then in both of the above cases
;; the result is the original `sextend` itself, so we can return `x` directly
;; (hence the single `u64_le` check covering both).
(rule (simplify (sshr wide
(ishl wide
x @ (sextend wide (value_type narrow))
(iconst_u _ shift_u64))
(iconst_u _ shift_u64)))
(if-let $true (u64_le shift_u64 (u64_sub (ty_bits_u64 wide) (ty_bits_u64 narrow))))
x)
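;; For example, with `wide = i64`, `narrow = i32`, and a shift of 16
;; (< 64 - 32): after `uextend` the top 32 bits are zero, so shifting
;; left and then arithmetically right by 16 round-trips the value, and
;; the whole expression is just the original extend.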
;; (x << N) >> N == x as T_SMALL as T_LARGE
;; if N == bitsizeof(T_LARGE) - bitsizeof(T_SMALL)
;;
;; Note that the shift is required to be >0 to ensure this doesn't accidentally
;; try to `ireduce` a type to itself, which isn't a valid use of `ireduce`.
(rule (simplify (sshr ty (ishl ty x (iconst _ shift)) (iconst _ shift)))
(if-let (u64_from_imm64 (u64_nonzero shift_u64)) shift)
(if-let ty_small (shift_amt_to_type (u64_sub (ty_bits ty) shift_u64)))
(sextend ty (ireduce ty_small x)))
(rule (simplify (ushr ty (ishl ty x (iconst _ shift)) (iconst _ shift)))
(if-let (u64_from_imm64 (u64_nonzero shift_u64)) shift)
(if-let ty_small (shift_amt_to_type (u64_sub (ty_bits ty) shift_u64)))
(uextend ty (ireduce ty_small x)))
(decl pure partial shift_amt_to_type (u64) Type)
(rule (shift_amt_to_type 8) $I8)
(rule (shift_amt_to_type 16) $I16)
(rule (shift_amt_to_type 32) $I32)
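;; As a worked instance of the two rules above: with `ty = i32` and
;; `shift = 24`, `ty_small` is `(shift_amt_to_type (32 - 24))` = $I8, so
;;
;;   (sshr (ishl x 24) 24) == (sextend $I32 (ireduce $I8 x))
;;   (ushr (ishl x 24) 24) == (uextend $I32 (ireduce $I8 x))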
;; ineg(ushr(x, k)) == sshr(x, k) when k == ty_bits - 1.
(rule (simplify (ineg ty (ushr ty x sconst @ (iconst_u ty shift_amt))))
(if-let $true (u64_eq shift_amt (ty_shift_mask ty)))
(sshr ty x sconst))
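;; E.g. for `ty = i32` the rule requires `shift_amt = 31`: `(ushr x 31)`
;; is either 0 or 1, so `ineg` of it is either 0 or -1 (all ones), which
;; matches broadcasting the sign bit with `(sshr x 31)`.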
;; Shifts and rotates allow a different type for the shift amount, so we
;; can remove any extend/reduce operations on the shift amount.
;;
;; (op x (ireduce y)) == (op x y)
;; (op x (uextend y)) == (op x y)
;; (op x (sextend y)) == (op x y)
;;
;; where `op` is one of ishl, ushr, sshr, rotl, rotr
;;
;; TODO: This rule is restricted to <=64 bits for ireduce since the x86
;; backend doesn't support SIMD shifts with 128-bit shift amounts.
(rule (simplify (ishl ty x (ireduce _ y @ (value_type (fits_in_64 _))))) (ishl ty x y))
(rule (simplify (ishl ty x (uextend _ y))) (ishl ty x y))
(rule (simplify (ishl ty x (sextend _ y))) (ishl ty x y))
(rule (simplify (ushr ty x (ireduce _ y @ (value_type (fits_in_64 _))))) (ushr ty x y))
(rule (simplify (ushr ty x (uextend _ y))) (ushr ty x y))
(rule (simplify (ushr ty x (sextend _ y))) (ushr ty x y))
(rule (simplify (sshr ty x (ireduce _ y @ (value_type (fits_in_64 _))))) (sshr ty x y))
(rule (simplify (sshr ty x (uextend _ y))) (sshr ty x y))
(rule (simplify (sshr ty x (sextend _ y))) (sshr ty x y))
(rule (simplify (rotr ty x (ireduce _ y @ (value_type (fits_in_64 _))))) (rotr ty x y))
(rule (simplify (rotr ty x (uextend _ y))) (rotr ty x y))
(rule (simplify (rotr ty x (sextend _ y))) (rotr ty x y))
(rule (simplify (rotl ty x (ireduce _ y @ (value_type (fits_in_64 _))))) (rotl ty x y))
(rule (simplify (rotl ty x (uextend _ y))) (rotl ty x y))
(rule (simplify (rotl ty x (sextend _ y))) (rotl ty x y))
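;; For example, `(ishl $I64 x (uextend $I64 k))` with `k` of type i8
;; becomes `(ishl $I64 x k)`: the extend cannot change which shift is
;; performed, since only the low bits of the amount are consulted.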
;; Remove `iconcat` from the shift amount input. This is correct even if the
;; `iconcat` operands have type i8, since an i8 can represent the largest
;; shift amount for i128 types (127).
;;
;; (op x (iconcat y1 y2)) == (op x y1)
;;
;; where `op` is one of ishl, ushr, sshr, rotl, rotr
(rule (simplify (ishl ty x (iconcat _ y _))) (ishl ty x y))
(rule (simplify (ushr ty x (iconcat _ y _))) (ushr ty x y))
(rule (simplify (sshr ty x (iconcat _ y _))) (sshr ty x y))
(rule (simplify (rotr ty x (iconcat _ y _))) (rotr ty x y))
(rule (simplify (rotl ty x (iconcat _ y _))) (rotl ty x y))
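;; E.g. for an i128 shift amount built as `(iconcat lo hi)`, only `lo`
;; can affect the masked amount (at most 127), so the whole `iconcat`
;; collapses to its low half.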
;; Try to combine the shift amounts of multiple consecutive shifts.
;; This only works if the combined shift amount remains smaller than the
;; bit width of the type.
;;
;; (ishl (ishl x k1) k2) == (ishl x (iadd k1 k2)) if (k1 & shift_mask) + (k2 & shift_mask) < ty_bits
;; (ushr (ushr x k1) k2) == (ushr x (iadd k1 k2)) if (k1 & shift_mask) + (k2 & shift_mask) < ty_bits
;; (sshr (sshr x k1) k2) == (sshr x (iadd k1 k2)) if (k1 & shift_mask) + (k2 & shift_mask) < ty_bits
(rule (simplify (ishl ty
(ishl ty x (iconst_u kty k1))
(iconst_u _ k2)))
(if-let shift_amt (u64_add
(u64_and k1 (ty_shift_mask ty))
(u64_and k2 (ty_shift_mask ty))))
(if-let $true (u64_lt shift_amt (ty_bits_u64 (lane_type ty))))
(ishl ty x (iconst_u kty shift_amt)))
(rule (simplify (ushr ty
(ushr ty x (iconst_u kty k1))
(iconst_u _ k2)))
(if-let shift_amt (u64_add
(u64_and k1 (ty_shift_mask ty))
(u64_and k2 (ty_shift_mask ty))))
(if-let $true (u64_lt shift_amt (ty_bits_u64 (lane_type ty))))
(ushr ty x (iconst_u kty shift_amt)))
(rule (simplify (sshr ty
(sshr ty x (iconst_u kty k1))
(iconst_u _ k2)))
(if-let shift_amt (u64_add
(u64_and k1 (ty_shift_mask ty))
(u64_and k2 (ty_shift_mask ty))))
(if-let $true (u64_lt shift_amt (ty_bits_u64 (lane_type ty))))
(sshr ty x (iconst_u kty shift_amt)))
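;; For example, with `ty = i32`:
;;
;;   (ishl (ishl x 10) 5) == (ishl x 15)
;;
;; since (10 & 31) + (5 & 31) = 15 < 32.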
;; Similarly, if the combined shift amount overflows the type, then we can
;; turn the whole expression into 0.
;;
;; (ishl (ishl x k1) k2) == 0 if (k1 & shift_mask) + (k2 & shift_mask) >= ty_bits
;; (ushr (ushr x k1) k2) == 0 if (k1 & shift_mask) + (k2 & shift_mask) >= ty_bits
(rule (simplify (ishl ty
(ishl ty x (iconst_u _ k1))
(iconst_u _ k2)))
(if-let shift_amt (u64_add
(u64_and k1 (ty_shift_mask ty))
(u64_and k2 (ty_shift_mask ty))))
(if-let $true (u64_le (ty_bits_u64 ty) shift_amt))
(subsume (iconst_u ty 0)))
(rule (simplify (ushr ty
(ushr ty x (iconst_u _ k1))
(iconst_u _ k2)))
(if-let shift_amt (u64_add
(u64_and k1 (ty_shift_mask ty))
(u64_and k2 (ty_shift_mask ty))))
(if-let $true (u64_le (ty_bits_u64 ty) shift_amt))
(subsume (iconst_u ty 0)))
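;; E.g. with `ty = i32`, `(ishl (ishl x 20) 20)` shifts out every bit
;; (20 + 20 = 40 >= 32), so it folds to 0. Note that `sshr` has no such
;; rule: repeated arithmetic shifts converge on the sign bit, not on 0.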
;; (rotl (rotr x y) y) == x
;; (rotr (rotl x y) y) == x
(rule (simplify (rotl ty (rotr ty x y) y)) (subsume x))
(rule (simplify (rotr ty (rotl ty x y) y)) (subsume x))
;; Emits an `iadd` of two values. If they have different types, then the
;; value with the smaller type is zero-extended to the larger type.
(decl iadd_uextend (Value Value) Value)
(rule 1 (iadd_uextend x @ (value_type ty) y @ (value_type ty))
(iadd ty x y))
(rule 2 (iadd_uextend x @ (value_type x_ty) y @ (value_type y_ty))
(if-let $true (u64_lt (ty_bits_u64 x_ty) (ty_bits_u64 y_ty)))
(iadd y_ty (uextend y_ty x) y))
(rule 3 (iadd_uextend x @ (value_type x_ty) y @ (value_type y_ty))
(if-let $true (u64_lt (ty_bits_u64 y_ty) (ty_bits_u64 x_ty)))
(iadd x_ty x (uextend x_ty y)))
;; Emits an `isub` of two values. If they have different types, then the
;; value with the smaller type is zero-extended to the larger type.
(decl isub_uextend (Value Value) Value)
(rule 1 (isub_uextend x @ (value_type ty) y @ (value_type ty))
(isub ty x y))
(rule 2 (isub_uextend x @ (value_type x_ty) y @ (value_type y_ty))
(if-let $true (u64_lt (ty_bits_u64 x_ty) (ty_bits_u64 y_ty)))
(isub y_ty (uextend y_ty x) y))
(rule 3 (isub_uextend x @ (value_type x_ty) y @ (value_type y_ty))
(if-let $true (u64_lt (ty_bits_u64 y_ty) (ty_bits_u64 x_ty)))
(isub x_ty x (uextend x_ty y)))
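;; Usage sketch for these helpers: `(isub_uextend y z)` with `y: i8` and
;; `z: i32` emits `(isub $I32 (uextend $I32 y) z)`, while same-typed
;; operands fall through to rule 1 and emit a plain `isub`.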
;; Try to group constants together so that other cprop rules can optimize them.
;;
;; (rotr (rotr x y) z) == (rotr x (iadd y z))
;; (rotl (rotl x y) z) == (rotl x (iadd y z))
;; (rotr (rotl x y) z) == (rotl x (isub y z))
;; (rotl (rotr x y) z) == (rotr x (isub y z))
;;
;; if y or z is a constant
(rule (simplify (rotl ty (rotl ty x y @ (iconst _ _)) z)) (rotl ty x (iadd_uextend y z)))
(rule (simplify (rotl ty (rotl ty x y) z @ (iconst _ _))) (rotl ty x (iadd_uextend y z)))
(rule (simplify (rotr ty (rotr ty x y @ (iconst _ _)) z)) (rotr ty x (iadd_uextend y z)))
(rule (simplify (rotr ty (rotr ty x y) z @ (iconst _ _))) (rotr ty x (iadd_uextend y z)))
(rule (simplify (rotr ty (rotl ty x y @ (iconst _ _)) z)) (rotl ty x (isub_uextend y z)))
(rule (simplify (rotr ty (rotl ty x y) z @ (iconst _ _))) (rotl ty x (isub_uextend y z)))
(rule (simplify (rotl ty (rotr ty x y @ (iconst _ _)) z)) (rotr ty x (isub_uextend y z)))
(rule (simplify (rotl ty (rotr ty x y) z @ (iconst _ _))) (rotr ty x (isub_uextend y z)))
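;; E.g. `(rotl (rotl x 3) z)` becomes `(rotl x (iadd 3 z))`; if `z` is
;; also a constant, cprop can then fold the `iadd` into a single rotate
;; amount.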
;; Similarly to the rules above, if y and z have the same type, we should emit
;; an iadd or isub instead. In some backends this is cheaper than a rotate.
;;
;; If they have different types we end up in a situation where we have to
;; insert an additional extend, and that transformation is not universally
;; beneficial.
;;
;; (rotr (rotr x y) z) == (rotr x (iadd y z))
;; (rotl (rotl x y) z) == (rotl x (iadd y z))
;; (rotr (rotl x y) z) == (rotl x (isub y z))
;; (rotl (rotr x y) z) == (rotr x (isub y z))
(rule (simplify (rotr ty (rotr ty x y @ (value_type kty)) z @ (value_type kty)))
(rotr ty x (iadd_uextend y z)))
(rule (simplify (rotl ty (rotl ty x y @ (value_type kty)) z @ (value_type kty)))
(rotl ty x (iadd_uextend y z)))
(rule (simplify (rotr ty (rotl ty x y @ (value_type kty)) z @ (value_type kty)))
(rotl ty x (isub_uextend y z)))
(rule (simplify (rotl ty (rotr ty x y @ (value_type kty)) z @ (value_type kty)))
(rotr ty x (isub_uextend y z)))
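;; For example, if `y` and `z` both have type i8 and `ty = i32`, then
;; `(rotr (rotl x y) z)` becomes `(rotl x (isub y z))`: rotating left by
;; `y` and then right by `z` is a net left rotation by `y - z` (mod 32).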
;; Convert shifts into rotates. We always normalize into a rotate left.
;;
;; (bor (ishl x k1) (ushr x k2)) == (rotl x k1) if k2 == ty_bits - k1
;; (bor (ushr x k2) (ishl x k1)) == (rotl x k1) if k2 == ty_bits - k1
;;
;; TODO: This rule is restricted to scalars since no backend currently
;; supports SIMD rotates.
(rule (simplify (bor (ty_int ty)
(ishl ty x k @ (iconst _ (u64_from_imm64 k1)))
(ushr ty x (iconst _ (u64_from_imm64 k2)))))
(if-let $true (u64_eq k2 (u64_sub (ty_bits_u64 (lane_type ty)) k1)))
(rotl ty x k))
(rule (simplify (bor (ty_int ty)
(ushr ty x (iconst _ (u64_from_imm64 k2)))
(ishl ty x k @ (iconst _ (u64_from_imm64 k1)))))
(if-let $true (u64_eq k2 (u64_sub (ty_bits_u64 (lane_type ty)) k1)))
(rotl ty x k))
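;; E.g. with `ty = i32` and `k1 = 8`:
;;
;;   (bor (ishl x 8) (ushr x 24)) == (rotl x 8)
;;
;; since 24 == 32 - 8.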
;; Normalize the shift amount. Some rules can't fire unless the shift amount
;; is normalized. This also helps us materialize fewer and smaller constants.
;;
;; (op x k) == (op x (and k (ty_shift_mask ty)))
;;
;; where `op` is one of ishl, ushr, sshr, rotl, rotr
(rule (simplify (ishl ty x (iconst_u kty k)))
(if-let $false (u64_eq k (u64_and k (ty_shift_mask ty))))
(ishl ty x (iconst_u kty (u64_and k (ty_shift_mask ty)))))
(rule (simplify (ushr ty x (iconst_u kty k)))
(if-let $false (u64_eq k (u64_and k (ty_shift_mask ty))))
(ushr ty x (iconst_u kty (u64_and k (ty_shift_mask ty)))))
(rule (simplify (sshr ty x (iconst_u kty k)))
(if-let $false (u64_eq k (u64_and k (ty_shift_mask ty))))
(sshr ty x (iconst_u kty (u64_and k (ty_shift_mask ty)))))
(rule (simplify (rotr ty x (iconst_u kty k)))
(if-let $false (u64_eq k (u64_and k (ty_shift_mask ty))))
(rotr ty x (iconst_u kty (u64_and k (ty_shift_mask ty)))))
(rule (simplify (rotl ty x (iconst_u kty k)))
(if-let $false (u64_eq k (u64_and k (ty_shift_mask ty))))
(rotl ty x (iconst_u kty (u64_and k (ty_shift_mask ty)))))
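;; For example, with `ty = i64` the shift mask is 63, so `(ishl x 65)`
;; normalizes to `(ishl x 1)`; both perform the same shift because the
;; amount is taken modulo the type's bit width.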