Skip to content

Commit

Permalink
Fix failing test test_pow
Browse files Browse the repository at this point in the history
  • Loading branch information
debugger22 committed Aug 18, 2015
1 parent 85f7baf commit 8173ae2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
39 changes: 27 additions & 12 deletions sympy/assumptions/tests/test_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,40 +19,55 @@ def test_Abs():


def test_pow():
assert refine((-1)**x, Q.even(x)) == 1
assert refine((-1)**x, Q.odd(x)) == -1
assert refine((-2)**x, Q.even(x)) == 2**x
x = Symbol('x', even=True)
assert refine((-1)**x) == 1
x = Symbol('x', odd=True)
assert refine((-1)**x) == -1
x = Symbol('x', even=True)
assert refine((-2)**x) == 2**x

# nested powers
x = Symbol('x')
assert refine(sqrt(x**2)) != Abs(x)
x = Symbol('x', complex=True)
assert refine(sqrt(x**2)) != Abs(x)
assert refine(sqrt(x**2), Q.complex(x)) != Abs(x)
assert refine(sqrt(x**2), Q.real(x)) == Abs(x)
x = Symbol('x', real=True)
assert refine(sqrt(x**2)) == Abs(x)
p = Symbol('p', positive=True)
assert refine(sqrt(p**2)) == p
x = Symbol('x')
assert refine((x**3)**(S(1)/3)) != x

assert refine((x**3)**(S(1)/3), Q.real(x)) != x
assert refine((x**3)**(S(1)/3), Q.positive(x)) == x

assert refine(sqrt(1/x), Q.real(x)) != 1/sqrt(x)
assert refine(sqrt(1/x), Q.positive(x)) == 1/sqrt(x)
x = Symbol('x', real=True)
assert refine((x**3)**(S(1)/3)) != x
x = Symbol('x', positive=True)
assert refine((x**3)**(S(1)/3)) == x
x = Symbol('x', real=True)
assert refine(sqrt(1/x)) != 1/sqrt(x)
x = Symbol('x', positive=True)
assert refine(sqrt(1/x)) == 1/sqrt(x)

# powers of (-1)
x = Symbol('x', even=True)
assert refine((-1)**(x + y), Q.even(x)) == (-1)**y
x = Symbol('x', odd=True)
z = Symbol('z', odd=True)
assert refine((-1)**(x + y + z), Q.odd(x) & Q.odd(z)) == (-1)**y
assert refine((-1)**(x + y + 1), Q.odd(x)) == (-1)**y
assert refine((-1)**(x + y + 2), Q.odd(x)) == (-1)**(y + 1)
x = Symbol('x')
assert refine((-1)**(x + 3)) == (-1)**(x + 1)

x = Symbol('x', integer=True)
assert refine((-1)**((-1)**x/2 - S.Half), Q.integer(x)) == (-1)**x
assert refine((-1)**((-1)**x/2 + S.Half), Q.integer(x)) == (-1)**(x + 1)
assert refine((-1)**((-1)**x/2 + 5*S.Half), Q.integer(x)) == (-1)**(x + 1)
assert refine((-1)**((-1)**x/2 - 7*S.Half), Q.integer(x)) == (-1)**(x + 1)
assert refine((-1)**((-1)**x/2 - 9*S.Half), Q.integer(x)) == (-1)**x

# powers of Abs
x = Symbol('x', real=True)
assert refine(Abs(x)**2, Q.real(x)) == x**2
assert refine(Abs(x)**3, Q.real(x)) == Abs(x)**3
x = Symbol('x')
assert refine(Abs(x)**2) == Abs(x)**2


Expand Down
17 changes: 13 additions & 4 deletions sympy/functions/elementary/complexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,20 +427,29 @@ def fdiff(self, argindex=1):
else:
raise ArgumentIndexError(self, argindex)

def _eval_refine(self, val=None):
from sympy.assumptions import Q, ask
def _eval_refine(self):
arg = self.args[0]
if arg.is_zero:
return S.Zero
if arg.is_nonnegative:
return arg
if arg.is_nonpositive:
return -arg
if arg.is_Add:
expr_list = []
for _arg in Add.make_args(arg):
if _arg.is_zero:
expr_list.append(S.Zero)
elif _arg.is_nonnegative:
expr_list.append(_arg)
elif _arg.is_nonpositive:
expr_list.append(-arg)
return Add(*expr_list)

@classmethod
def eval(cls, arg):
from sympy.simplify.simplify import signsimp
from sympy.assumptions import refine
from sympy.physics.units import Unit
if hasattr(arg, '_eval_Abs'):
obj = arg._eval_Abs()
if obj is not None:
Expand Down Expand Up @@ -479,7 +488,7 @@ def eval(cls, arg):
return (-base)**re(exponent)*exp(-S.Pi*im(exponent))
if isinstance(arg, exp):
return exp(re(arg.args[0]))
if arg.is_number or arg.is_Symbol or isinstance(arg, cls):
if arg.is_number or arg.is_Symbol or isinstance(arg, (cls, Unit)):
if arg.is_zero:
return S.Zero
if arg.is_nonnegative:
Expand Down
4 changes: 2 additions & 2 deletions sympy/integrals/tests/test_meijerint.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def t(a, b, arg, n):


def test_recursive():
from sympy import symbols
from sympy import symbols, refine
a, b, c = symbols('a b c', positive=True)
r = exp(-(x - a)**2)*exp(-(x - b)**2)
e = integrate(r, (x, 0, oo), meijerg=True)
Expand All @@ -112,7 +112,7 @@ def test_recursive():
+ (2*a + 2*b + c)**2/8)/4)
assert simplify(integrate(exp(-(x - a - b - c)**2), (x, 0, oo), meijerg=True)) == \
sqrt(pi)/2*(1 + erf(a + b + c))
assert simplify(integrate(exp(-(x + a + b + c)**2), (x, 0, oo), meijerg=True)) == \
assert simplify(refine(integrate(exp(-(x + a + b + c)**2), (x, 0, oo), meijerg=True))) == \
sqrt(pi)/2*(1 - erf(a + b + c))


Expand Down

0 comments on commit 8173ae2

Please sign in to comment.