Old engine for Continuous Time Bayesian Networks. Superseded by reCTBN. 🐍
https://github.com/madlabunimib/PyCTBN
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
293 lines
9.9 KiB
293 lines
9.9 KiB
4 years ago
|
"""
|
||
|
Tests for line search routines
|
||
|
"""
|
||
|
from numpy.testing import (assert_, assert_equal, assert_array_almost_equal,
|
||
|
assert_array_almost_equal_nulp, assert_warns,
|
||
|
suppress_warnings)
|
||
|
import scipy.optimize.linesearch as ls
|
||
|
from scipy.optimize.linesearch import LineSearchWarning
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
def assert_wolfe(s, phi, derphi, c1=1e-4, c2=0.9, err_msg=""):
|
||
|
"""
|
||
|
Check that strong Wolfe conditions apply
|
||
|
"""
|
||
|
phi1 = phi(s)
|
||
|
phi0 = phi(0)
|
||
|
derphi0 = derphi(0)
|
||
|
derphi1 = derphi(s)
|
||
|
msg = "s = %s; phi(0) = %s; phi(s) = %s; phi'(0) = %s; phi'(s) = %s; %s" % (
|
||
|
s, phi0, phi1, derphi0, derphi1, err_msg)
|
||
|
|
||
|
assert_(phi1 <= phi0 + c1*s*derphi0, "Wolfe 1 failed: " + msg)
|
||
|
assert_(abs(derphi1) <= abs(c2*derphi0), "Wolfe 2 failed: " + msg)
|
||
|
|
||
|
|
||
|
def assert_armijo(s, phi, c1=1e-4, err_msg=""):
|
||
|
"""
|
||
|
Check that Armijo condition applies
|
||
|
"""
|
||
|
phi1 = phi(s)
|
||
|
phi0 = phi(0)
|
||
|
msg = "s = %s; phi(0) = %s; phi(s) = %s; %s" % (s, phi0, phi1, err_msg)
|
||
|
assert_(phi1 <= (1 - c1*s)*phi0, msg)
|
||
|
|
||
|
|
||
|
def assert_line_wolfe(x, p, s, f, fprime, **kw):
|
||
|
assert_wolfe(s, phi=lambda sp: f(x + p*sp),
|
||
|
derphi=lambda sp: np.dot(fprime(x + p*sp), p), **kw)
|
||
|
|
||
|
|
||
|
def assert_line_armijo(x, p, s, f, **kw):
|
||
|
assert_armijo(s, phi=lambda sp: f(x + p*sp), **kw)
|
||
|
|
||
|
|
||
|
def assert_fp_equal(x, y, err_msg="", nulp=50):
|
||
|
"""Assert two arrays are equal, up to some floating-point rounding error"""
|
||
|
try:
|
||
|
assert_array_almost_equal_nulp(x, y, nulp)
|
||
|
except AssertionError as e:
|
||
|
raise AssertionError("%s\n%s" % (e, err_msg))
|
||
|
|
||
|
|
||
|
class TestLineSearch(object):
|
||
|
# -- scalar functions; must have dphi(0.) < 0
|
||
|
def _scalar_func_1(self, s):
|
||
|
self.fcount += 1
|
||
|
p = -s - s**3 + s**4
|
||
|
dp = -1 - 3*s**2 + 4*s**3
|
||
|
return p, dp
|
||
|
|
||
|
def _scalar_func_2(self, s):
|
||
|
self.fcount += 1
|
||
|
p = np.exp(-4*s) + s**2
|
||
|
dp = -4*np.exp(-4*s) + 2*s
|
||
|
return p, dp
|
||
|
|
||
|
def _scalar_func_3(self, s):
|
||
|
self.fcount += 1
|
||
|
p = -np.sin(10*s)
|
||
|
dp = -10*np.cos(10*s)
|
||
|
return p, dp
|
||
|
|
||
|
# -- n-d functions
|
||
|
|
||
|
def _line_func_1(self, x):
|
||
|
self.fcount += 1
|
||
|
f = np.dot(x, x)
|
||
|
df = 2*x
|
||
|
return f, df
|
||
|
|
||
|
def _line_func_2(self, x):
|
||
|
self.fcount += 1
|
||
|
f = np.dot(x, np.dot(self.A, x)) + 1
|
||
|
df = np.dot(self.A + self.A.T, x)
|
||
|
return f, df
|
||
|
|
||
|
# --
|
||
|
|
||
|
def setup_method(self):
|
||
|
self.scalar_funcs = []
|
||
|
self.line_funcs = []
|
||
|
self.N = 20
|
||
|
self.fcount = 0
|
||
|
|
||
|
def bind_index(func, idx):
|
||
|
# Remember Python's closure semantics!
|
||
|
return lambda *a, **kw: func(*a, **kw)[idx]
|
||
|
|
||
|
for name in sorted(dir(self)):
|
||
|
if name.startswith('_scalar_func_'):
|
||
|
value = getattr(self, name)
|
||
|
self.scalar_funcs.append(
|
||
|
(name, bind_index(value, 0), bind_index(value, 1)))
|
||
|
elif name.startswith('_line_func_'):
|
||
|
value = getattr(self, name)
|
||
|
self.line_funcs.append(
|
||
|
(name, bind_index(value, 0), bind_index(value, 1)))
|
||
|
|
||
|
np.random.seed(1234)
|
||
|
self.A = np.random.randn(self.N, self.N)
|
||
|
|
||
|
def scalar_iter(self):
|
||
|
for name, phi, derphi in self.scalar_funcs:
|
||
|
for old_phi0 in np.random.randn(3):
|
||
|
yield name, phi, derphi, old_phi0
|
||
|
|
||
|
def line_iter(self):
|
||
|
for name, f, fprime in self.line_funcs:
|
||
|
k = 0
|
||
|
while k < 9:
|
||
|
x = np.random.randn(self.N)
|
||
|
p = np.random.randn(self.N)
|
||
|
if np.dot(p, fprime(x)) >= 0:
|
||
|
# always pick a descent direction
|
||
|
continue
|
||
|
k += 1
|
||
|
old_fv = float(np.random.randn())
|
||
|
yield name, f, fprime, x, p, old_fv
|
||
|
|
||
|
# -- Generic scalar searches
|
||
|
|
||
|
def test_scalar_search_wolfe1(self):
|
||
|
c = 0
|
||
|
for name, phi, derphi, old_phi0 in self.scalar_iter():
|
||
|
c += 1
|
||
|
s, phi1, phi0 = ls.scalar_search_wolfe1(phi, derphi, phi(0),
|
||
|
old_phi0, derphi(0))
|
||
|
assert_fp_equal(phi0, phi(0), name)
|
||
|
assert_fp_equal(phi1, phi(s), name)
|
||
|
assert_wolfe(s, phi, derphi, err_msg=name)
|
||
|
|
||
|
assert_(c > 3) # check that the iterator really works...
|
||
|
|
||
|
def test_scalar_search_wolfe2(self):
|
||
|
for name, phi, derphi, old_phi0 in self.scalar_iter():
|
||
|
s, phi1, phi0, derphi1 = ls.scalar_search_wolfe2(
|
||
|
phi, derphi, phi(0), old_phi0, derphi(0))
|
||
|
assert_fp_equal(phi0, phi(0), name)
|
||
|
assert_fp_equal(phi1, phi(s), name)
|
||
|
if derphi1 is not None:
|
||
|
assert_fp_equal(derphi1, derphi(s), name)
|
||
|
assert_wolfe(s, phi, derphi, err_msg="%s %g" % (name, old_phi0))
|
||
|
|
||
|
def test_scalar_search_wolfe2_with_low_amax(self):
|
||
|
def phi(alpha):
|
||
|
return (alpha - 5) ** 2
|
||
|
|
||
|
def derphi(alpha):
|
||
|
return 2 * (alpha - 5)
|
||
|
|
||
|
s, _, _, _ = assert_warns(LineSearchWarning,
|
||
|
ls.scalar_search_wolfe2, phi, derphi, amax=0.001)
|
||
|
assert_(s is None)
|
||
|
|
||
|
def test_scalar_search_armijo(self):
|
||
|
for name, phi, derphi, old_phi0 in self.scalar_iter():
|
||
|
s, phi1 = ls.scalar_search_armijo(phi, phi(0), derphi(0))
|
||
|
assert_fp_equal(phi1, phi(s), name)
|
||
|
assert_armijo(s, phi, err_msg="%s %g" % (name, old_phi0))
|
||
|
|
||
|
# -- Generic line searches
|
||
|
|
||
|
def test_line_search_wolfe1(self):
|
||
|
c = 0
|
||
|
smax = 100
|
||
|
for name, f, fprime, x, p, old_f in self.line_iter():
|
||
|
f0 = f(x)
|
||
|
g0 = fprime(x)
|
||
|
self.fcount = 0
|
||
|
s, fc, gc, fv, ofv, gv = ls.line_search_wolfe1(f, fprime, x, p,
|
||
|
g0, f0, old_f,
|
||
|
amax=smax)
|
||
|
assert_equal(self.fcount, fc+gc)
|
||
|
assert_fp_equal(ofv, f(x))
|
||
|
if s is None:
|
||
|
continue
|
||
|
assert_fp_equal(fv, f(x + s*p))
|
||
|
assert_array_almost_equal(gv, fprime(x + s*p), decimal=14)
|
||
|
if s < smax:
|
||
|
c += 1
|
||
|
assert_line_wolfe(x, p, s, f, fprime, err_msg=name)
|
||
|
|
||
|
assert_(c > 3) # check that the iterator really works...
|
||
|
|
||
|
def test_line_search_wolfe2(self):
|
||
|
c = 0
|
||
|
smax = 512
|
||
|
for name, f, fprime, x, p, old_f in self.line_iter():
|
||
|
f0 = f(x)
|
||
|
g0 = fprime(x)
|
||
|
self.fcount = 0
|
||
|
with suppress_warnings() as sup:
|
||
|
sup.filter(LineSearchWarning,
|
||
|
"The line search algorithm could not find a solution")
|
||
|
sup.filter(LineSearchWarning,
|
||
|
"The line search algorithm did not converge")
|
||
|
s, fc, gc, fv, ofv, gv = ls.line_search_wolfe2(f, fprime, x, p,
|
||
|
g0, f0, old_f,
|
||
|
amax=smax)
|
||
|
assert_equal(self.fcount, fc+gc)
|
||
|
assert_fp_equal(ofv, f(x))
|
||
|
assert_fp_equal(fv, f(x + s*p))
|
||
|
if gv is not None:
|
||
|
assert_array_almost_equal(gv, fprime(x + s*p), decimal=14)
|
||
|
if s < smax:
|
||
|
c += 1
|
||
|
assert_line_wolfe(x, p, s, f, fprime, err_msg=name)
|
||
|
assert_(c > 3) # check that the iterator really works...
|
||
|
|
||
|
def test_line_search_wolfe2_bounds(self):
|
||
|
# See gh-7475
|
||
|
|
||
|
# For this f and p, starting at a point on axis 0, the strong Wolfe
|
||
|
# condition 2 is met if and only if the step length s satisfies
|
||
|
# |x + s| <= c2 * |x|
|
||
|
f = lambda x: np.dot(x, x)
|
||
|
fp = lambda x: 2 * x
|
||
|
p = np.array([1, 0])
|
||
|
|
||
|
# Smallest s satisfying strong Wolfe conditions for these arguments is 30
|
||
|
x = -60 * p
|
||
|
c2 = 0.5
|
||
|
|
||
|
s, _, _, _, _, _ = ls.line_search_wolfe2(f, fp, x, p, amax=30, c2=c2)
|
||
|
assert_line_wolfe(x, p, s, f, fp)
|
||
|
|
||
|
s, _, _, _, _, _ = assert_warns(LineSearchWarning,
|
||
|
ls.line_search_wolfe2, f, fp, x, p,
|
||
|
amax=29, c2=c2)
|
||
|
assert_(s is None)
|
||
|
|
||
|
# s=30 will only be tried on the 6th iteration, so this won't converge
|
||
|
assert_warns(LineSearchWarning, ls.line_search_wolfe2, f, fp, x, p,
|
||
|
c2=c2, maxiter=5)
|
||
|
|
||
|
def test_line_search_armijo(self):
|
||
|
c = 0
|
||
|
for name, f, fprime, x, p, old_f in self.line_iter():
|
||
|
f0 = f(x)
|
||
|
g0 = fprime(x)
|
||
|
self.fcount = 0
|
||
|
s, fc, fv = ls.line_search_armijo(f, x, p, g0, f0)
|
||
|
c += 1
|
||
|
assert_equal(self.fcount, fc)
|
||
|
assert_fp_equal(fv, f(x + s*p))
|
||
|
assert_line_armijo(x, p, s, f, err_msg=name)
|
||
|
assert_(c >= 9)
|
||
|
|
||
|
# -- More specific tests
|
||
|
|
||
|
def test_armijo_terminate_1(self):
|
||
|
# Armijo should evaluate the function only once if the trial step
|
||
|
# is already suitable
|
||
|
count = [0]
|
||
|
|
||
|
def phi(s):
|
||
|
count[0] += 1
|
||
|
return -s + 0.01*s**2
|
||
|
s, phi1 = ls.scalar_search_armijo(phi, phi(0), -1, alpha0=1)
|
||
|
assert_equal(s, 1)
|
||
|
assert_equal(count[0], 2)
|
||
|
assert_armijo(s, phi)
|
||
|
|
||
|
def test_wolfe_terminate(self):
|
||
|
# wolfe1 and wolfe2 should also evaluate the function only a few
|
||
|
# times if the trial step is already suitable
|
||
|
|
||
|
def phi(s):
|
||
|
count[0] += 1
|
||
|
return -s + 0.05*s**2
|
||
|
|
||
|
def derphi(s):
|
||
|
count[0] += 1
|
||
|
return -1 + 0.05*2*s
|
||
|
|
||
|
for func in [ls.scalar_search_wolfe1, ls.scalar_search_wolfe2]:
|
||
|
count = [0]
|
||
|
r = func(phi, derphi, phi(0), None, derphi(0))
|
||
|
assert_(r[0] is not None, (r, func))
|
||
|
assert_(count[0] <= 2 + 2, (count, func))
|
||
|
assert_wolfe(r[0], phi, derphi, err_msg=str(func))
|