| import sympy |
| from sympy import S |
| from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or |
| |
# Public names exported by ``from <this module> import *``.
# NOTE(review): Where and Mod are also defined in this module but are not
# listed here — confirm that keeping them out of the star-export is intended.
__all__ = ["FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", "LShift", "RShift"]
| |
| |
class FloorDiv(sympy.Function):
    """
    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
    """
    nargs = (2,)
    precedence = 50  # precedence of mul  # noqa: F811

    # Default return type for SymPy assumptions.
    # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
    is_real = True

    @property
    def base(self):
        """The dividend (first argument)."""
        return self.args[0]

    @property
    def divisor(self):
        """The divisor (second argument)."""
        return self.args[1]

    def _sympystr(self, printer):
        # Render as "(base//divisor)", letting the printer parenthesize any
        # operand whose precedence is below ``self.precedence``.
        base = printer.parenthesize(self.base, self.precedence)
        divisor = printer.parenthesize(self.divisor, self.precedence)
        return f"({base}//{divisor})"

    # SymPy assumptions based on argument types.
    def _eval_is_real(self):
        return fuzzy_or([self.base.is_real, self.divisor.is_real])

    def _eval_is_integer(self):
        return fuzzy_and([self.base.is_integer, self.divisor.is_integer])

    # Automatic evaluation.
    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
    @classmethod
    def eval(cls, base, divisor):
        """Simplify ``base // divisor``; return None to leave it unevaluated.

        Raises:
            TypeError: if either operand is a Boolean, or is known to be a
                complex number that is neither integer nor real.
            ZeroDivisionError: if ``divisor`` is known to be zero.
        """
        def check_supported_type(x):
            if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
                raise TypeError(
                    f"unsupported operand type(s) for //: "
                    f"'{type(base).__name__}' and '{type(divisor).__name__}'"
                    f", expected integer or real")

        check_supported_type(base)
        check_supported_type(divisor)

        # We don't provide the same error message as in Python because SymPy
        # makes it difficult to check the types.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        if base.is_zero:
            return sympy.S.Zero
        if base.is_integer and divisor == 1:
            return base
        if base.is_real and divisor == 1:
            return sympy.floor(base)
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return base // divisor
        if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)):
            return sympy.floor(base / divisor)
        if isinstance(base, FloorDiv):
            # Nested floor division collapses: (a // b) // c -> a // (b * c).
            # NOTE(review): this identity holds for a positive integer c;
            # presumably that is guaranteed by callers — confirm.
            return FloorDiv(base.args[0], base.args[1] * divisor)
        if isinstance(divisor, sympy.Rational) and divisor.p == 1:
            # Dividing by 1/q is multiplying by q.
            return sympy.floor(base * divisor.q)

        if isinstance(base, sympy.Add):
            # A summand that is a multiple of the divisor can be pulled out:
            # (rest + k * divisor) // divisor == rest // divisor + k.
            for a in base.args:
                try:
                    gcd = sympy.gcd(a, divisor)
                except sympy.PolynomialError:
                    # Same failure as the guarded gcd below
                    # (https://github.com/pytorch/pytorch/issues/108276);
                    # leaving this summand in place is always correct.
                    continue
                if gcd == divisor:
                    return FloorDiv(base - a, divisor) + a / gcd

        try:
            # Cancel a common factor: (g*a) // (g*b) -> a // b.
            gcd = sympy.gcd(base, divisor)
            if gcd != 1:
                return FloorDiv(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276
| |
| |
class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        """Simplify ``(base // divisor) % modulus``; return None to leave it unevaluated."""
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        try:
            if divisor != 1:
                # Cancel a common factor of base and divisor:
                # ((g*a) // (g*b)) % c == (a // b) % c.
                gcd = sympy.gcd(base, divisor)
                if gcd != 1:
                    return ModularIndexing(
                        sympy.simplify(base / gcd), sympy.simplify(divisor / gcd), modulus
                    )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276

        if isinstance(base, sympy.Add):
            # A summand that is a multiple of divisor * modulus contributes a
            # multiple of modulus to base // divisor, so it can be dropped.
            period = modulus * divisor
            new_terms = []
            all_positive = True
            for term in base.args:
                try:
                    is_multiple = sympy.gcd(term, period) == period
                except sympy.PolynomialError:
                    # Same failure as the guarded gcd above
                    # (https://github.com/pytorch/pytorch/issues/108276);
                    # keeping the term is always correct.
                    is_multiple = False
                if not is_multiple:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        # workaround for https://github.com/openai/triton/issues/619,
                        # if there are negative terms, // produces wrong result
                        # TODO if https://github.com/openai/triton/issues/619 is fixed
                        # this optimization would become valid
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)

            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        if isinstance(base, FloorDiv):
            # Fold an inner floor division into the divisor.
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
| |
class Where(sympy.Function):
    """
    Symbolic conditional: Where(c, p, q) stands for ``p if c else q``.

    The node collapses to ``p`` or ``q`` only once ``c`` is literally
    ``sympy.true`` or ``sympy.false``; any other condition keeps it unevaluated.
    """

    nargs = (3,)

    @classmethod
    def eval(cls, c, p, q):
        # Returning None tells SymPy to keep Where(c, p, q) as-is.
        if c == sympy.false:
            return q
        if c == sympy.true:
            return p
        return None
| |
class Mod(sympy.Function):
    """
    We maintain this so that we avoid SymPy correctness issues, such as:
    https://github.com/sympy/sympy/issues/25146
    """

    nargs = (2,)

    @classmethod
    def eval(cls, p, q):
        """Simplify ``p % q``; return None to leave it unevaluated.

        Literal operands defer to ``p % q``. The assumption handlers below
        treat the result as nonnegative for positive ``q`` and nonpositive for
        negative ``q``.

        Raises:
            ZeroDivisionError: if ``q`` is known to be zero.
        """
        # This was adapted from: sympy/core/mod.py

        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")
        # If either of them is NaN or infinite.
        if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False:
            return S.NaN
        # Three cases:
        #   1. p == 0
        #   2. p is either q or -q
        #   3. p is integer and q == 1
        if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1):
            return S.Zero

        # Evaluate if they are both literals.
        if q.is_Number and p.is_Number:
            return p % q

        # If q == 2, it's a matter of whether p is odd or even.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        # If p is a multiple of q.
        r = p / q
        if r.is_integer:
            return S.Zero

        # If 0 < p / q < 1 then floor(p / q) == 0, so
        #   p % q = p - floor(p / q) * q = p.
        # With p / q > 0 that holds exactly when 0 < p < q (q > 0) or
        # q < p < 0 (q < 0). Testing ``p < q`` alone is wrong for negative
        # operands: there it means |p| > |q| (e.g. -3 % -2 == -1, not -3).
        if r.is_positive:
            if q.is_positive:
                within = p < q
            elif q.is_negative:
                within = p > q
            else:
                within = None
            if within is sympy.true:
                return p

    def _eval_is_integer(self):
        p, q = self.args
        return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)])  # type: ignore[attr-defined]

    def _eval_is_nonnegative(self):
        return True if self.args[1].is_positive else None  # type: ignore[attr-defined]

    def _eval_is_nonpositive(self):
        return True if self.args[1].is_negative else None  # type: ignore[attr-defined]
| |
| |
class CleanDiv(FloorDiv):
    """
    FloorDiv variant for a division that is known to be exact (no rounding).

    It inherits all of its behavior from FloorDiv; the separate type only
    records the exactness so that future optimizations can rely on it.
    """
| |
| |
class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.

    Constructing CeilDiv(base, divisor) never yields a CeilDiv instance:
    ``__new__`` returns CleanDiv(base, divisor) when the divisor is found to
    divide base, and FloorDiv(base + (divisor - 1), divisor) otherwise.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        # NOTE(review): ``(base + divisor - 1) // divisor`` equals
        # ceil(base / divisor) only for a positive integer divisor —
        # presumably always the case for indexing; confirm with callers.
        try:
            exact = sympy.gcd(base, divisor) == divisor
        except sympy.PolynomialError:
            # https://github.com/pytorch/pytorch/issues/108276 — the FloorDiv
            # form below is valid whether or not the division is exact.
            exact = False
        if exact:
            return CleanDiv(base, divisor)
        return FloorDiv(base + (divisor - 1), divisor)
| |
| |
class LShift(sympy.Function):
    """
    Left shift: LShift(base, shift) => base * 2**shift.

    ``eval`` always returns a value (or raises), so the node is never left
    unevaluated.
    """

    @classmethod
    def eval(cls, base, shift):
        # ``shift < 0`` on a symbol of unknown sign produces an unevaluated
        # relational, and taking its truth value raises TypeError. Reject only
        # shifts that SymPy can prove negative.
        if shift.is_negative:
            raise ValueError('negative shift count')
        return base * 2 ** shift
| |
| |
class RShift(sympy.Function):
    """
    Right shift: RShift(base, shift) => base // 2**shift.

    ``eval`` always returns a value (or raises), so the node is never left
    unevaluated.
    """

    @classmethod
    def eval(cls, base, shift):
        # ``shift < 0`` on a symbol of unknown sign produces an unevaluated
        # relational, and taking its truth value raises TypeError. Reject only
        # shifts that SymPy can prove negative.
        if shift.is_negative:
            raise ValueError('negative shift count')
        return base // 2 ** shift