blob: eae54f1f4e3d7779aab97aafdbf53126438c4853 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// These are the legalization patterns that convert complex operations into
// equivalent real-valued operations.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Func/IR/FuncOps.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
// Complex addition and subtraction act independently on the real and
// imaginary parts, so the same op is applied once per component and the two
// results are recombined into a complex value:
//   result.real = lhs.real (op) rhs.real
//   result.imag = lhs.imag (op) rhs.imag
foreach ComponentwiseOp = [HLO_AddOp, HLO_SubtractOp] in
  def : Pat<
    (ComponentwiseOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
    (HLO_ComplexOp
      (ComponentwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)),
      (ComponentwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>;
// Complex multiplication expands into four real multiplications that are
// cross-combined:
//   result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
//   result.imag = lhs.real * rhs.imag + lhs.imag * rhs.real
// Each component extraction is bound to a symbol the first time it is built
// (while forming result.real) and reused when forming result.imag.
def : Pat<
  (HLO_MulOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
  (HLO_ComplexOp
    (HLO_SubtractOp
      (HLO_MulOp (HLO_RealOp:$lhs_re $lhs), (HLO_RealOp:$rhs_re $rhs)),
      (HLO_MulOp (HLO_ImagOp:$lhs_im $lhs), (HLO_ImagOp:$rhs_im $rhs))),
    (HLO_AddOp
      (HLO_MulOp $lhs_re, $rhs_im),
      (HLO_MulOp $lhs_im, $rhs_re)))>;
// Complex division multiplies through by the conjugate of the divisor so that
// the denominator becomes a real value:
//   numerator   = lhs * conj(rhs)                              (complex)
//   denominator = rhs.real * rhs.real + rhs.imag * rhs.imag    (real)
//   result.real = numerator.real / denominator
//   result.imag = numerator.imag / denominator
// The numerator is emitted as a complex HLO_MulOp; it is presumably lowered in
// turn by the complex multiplication pattern in this file (depends on how the
// rewrite driver applies these patterns -- confirm against the pass).
def : Pat<
  (HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
  (HLO_ComplexOp
    (HLO_DivOp
      (HLO_RealOp
        (HLO_MulOp:$numerator
          $lhs,
          (HLO_ComplexOp:$rhs_conj
            (HLO_RealOp $rhs),
            (HLO_NegOp (HLO_ImagOp $rhs))))),
      (HLO_AddOp:$denominator
        (HLO_MulOp (HLO_RealOp $rhs), (HLO_RealOp $rhs)),
        (HLO_MulOp (HLO_ImagOp $rhs), (HLO_ImagOp $rhs)))),
    (HLO_DivOp (HLO_ImagOp $numerator), $denominator))>;
// The absolute value (magnitude) of a complex value is the Euclidean norm of
// its two components:
//   result = sqrt(val.real * val.real + val.imag * val.imag)
// NOTE(review): the components are squared directly, so the intermediate
// products can overflow or underflow before the square root is taken.
def : Pat<
  (HLO_AbsOp HLO_ComplexTensor:$val),
  (HLO_SqrtOp
    (HLO_AddOp
      (HLO_MulOp (HLO_RealOp:$re $val), $re),
      (HLO_MulOp (HLO_ImagOp:$im $val), $im)))>;
// Sine of a complex value a + ib decomposes as:
//   sin(a + ib) = sin(a) * cosh(b) + i * cos(a) * sinh(b)
// with the hyperbolic functions written in terms of exponentials:
//   cosh(b) = (e^b + e^-b) / 2
//   sinh(b) = (e^b - e^-b) / 2
// The trig factor is multiplied by the exponential sum/difference first and
// the product is then divided by a splat constant 2.0 (built from the real
// component) that is shared by both result components.
def : Pat<
  (HLO_SineOp HLO_ComplexTensor:$val),
  (HLO_ComplexOp
    (HLO_DivOp
      (HLO_MulOp
        (HLO_SineOp (HLO_RealOp:$re $val)),
        (HLO_AddOp
          (HLO_ExpOp:$exp_pos (HLO_ImagOp:$im $val)),
          (HLO_ExpOp:$exp_neg (HLO_NegOp $im)))),
      (HLO_ConstantOp:$two (ConstantSplat<"2.0"> $re))),
    (HLO_DivOp
      (HLO_MulOp
        (HLO_CosineOp $re),
        (HLO_SubtractOp $exp_pos, $exp_neg)),
      $two))>;
// Cosine of a complex value a + ib decomposes as:
//   cos(a + ib) = cos(a) * cosh(b) - i * sin(a) * sinh(b)
// with the hyperbolic functions written in terms of exponentials:
//   cosh(b) = (e^b + e^-b) / 2
//   sinh(b) = (e^b - e^-b) / 2
// The minus sign on the imaginary part is folded into the subtraction by
// computing (e^-b - e^b), i.e. -2 * sinh(b), instead of negating afterwards.
def : Pat<
  (HLO_CosineOp HLO_ComplexTensor:$val),
  (HLO_ComplexOp
    (HLO_DivOp
      (HLO_MulOp
        (HLO_CosineOp (HLO_RealOp:$re $val)),
        (HLO_AddOp
          (HLO_ExpOp:$exp_pos (HLO_ImagOp:$im $val)),
          (HLO_ExpOp:$exp_neg (HLO_NegOp $im)))),
      (HLO_ConstantOp:$two (ConstantSplat<"2.0"> $re))),
    (HLO_DivOp
      (HLO_MulOp
        (HLO_SineOp $re),
        (HLO_SubtractOp $exp_neg, $exp_pos)),
      $two))>;
// Helper that turns the name of an mhlo ComparisonDirection enumerator
// (given as `enumStr`) into a constant HLO_ComparisonDirectionAttr.
class HLO_ComparisonDirectionValue<string enumStr>
    : ConstantAttr<HLO_ComparisonDirectionAttr,
                   "::mlir::mhlo::ComparisonDirection::" # enumStr>;

// The exponential of a complex value is lowered to a real exponential of the
// real component scaled by the cosine and sine of the imaginary component,
// i.e. a real exponential multiplied by Euler's formula:
//   exp(a + ib) = exp(a) * exp(ib) = exp(a) * cos(b) + i * exp(a) * sin(b)
def : Pat<
  (HLO_ExpOp HLO_ComplexTensor:$val),
  (HLO_ComplexOp
    (HLO_MulOp
      (HLO_CosineOp (HLO_ImagOp:$im $val)),
      (HLO_ExpOp:$exp_re (HLO_RealOp:$re $val))),
    (HLO_MulOp (HLO_SineOp $im), $exp_re))>;
// Equality-style comparisons of complex values are evaluated per component
// and the two boolean results are merged:
//   NE: real parts differ OR  imaginary parts differ
//   EQ: real parts equal  AND imaginary parts equal
// Each list entry is [comparison direction, op that merges the two results].
foreach dirAndCombiner = [[HLO_ComparisonDirectionValue<"NE">, HLO_OrOp],
                          [HLO_ComparisonDirectionValue<"EQ">, HLO_AndOp]] in {
  def : Pat<
    (HLO_CompareOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs,
                   dirAndCombiner[0], $compare_type),
    (dirAndCombiner[1]
      (HLO_CompareOp (HLO_RealOp $lhs), (HLO_RealOp $rhs),
                     dirAndCombiner[0], $compare_type),
      (HLO_CompareOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs),
                     dirAndCombiner[0], $compare_type))>;
}