Skip to content

Commit e60507a

Browse files
committed
FEATURE: renaming functions and adding tests
1 parent 5637652 commit e60507a

File tree

2 files changed

+47
-36
lines changed

2 files changed

+47
-36
lines changed

Diff for: pdesolvers/tests/test_black_scholes.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pdesolvers.pdes.black_scholes as bse
55
import pdesolvers.solvers.black_scholes_solvers as solver
6+
import pdesolvers.utils.utility as utility
67

78
class TestBlackScholesSolvers:
89

@@ -74,10 +75,19 @@ def test_check_absolute_difference_between_two_results(self):
7475

7576
assert np.max(np.abs(diff)) < 1e-2
7677

77-
# X, Y = np.meshgrid(result1.t_grid, result1.s_grid)
78-
#
79-
# fig = plt.figure(figsize=(10,6))
80-
# ax = fig.add_subplot(111, projection='3d')
81-
# surf = ax.plot_surface(X, Y, diff, cmap='viridis')
82-
# print(np.max(np.abs(diff)))
83-
# plt.show()
78+
def test_convergence_between_interpolated_data(self):
79+
result1 = solver.BlackScholesExplicitSolver(self.equation).solve()
80+
result2 = solver.BlackScholesCNSolver(self.equation).solve()
81+
u1 = result1.get_result()
82+
u2 = result2.get_result()
83+
84+
data1 = utility.RBFInterpolator(u1, 0.1, 0.03).interpolate(4,20)
85+
data2 = utility.RBFInterpolator(u2,0.1, 0.03).interpolate(4,20)
86+
87+
diff = np.abs(data1 - data2)
88+
89+
assert diff < 1e-4
90+
91+
92+
93+

Diff for: pdesolvers/utils/utility.py

+30-29
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
class RBFInterpolator:
55

6-
def __init__(self, z, x, y, hx, hy):
6+
def __init__(self, z, hx, hy):
77
"""
88
Initializes the RBF Interpolator.
99
@@ -14,43 +14,41 @@ def __init__(self, z, x, y, hx, hy):
1414
:param hy: Grid spacing in the y-direction.
1515
"""
1616

17-
self._z = z
18-
self._x = x
19-
self._y = y
20-
self._hx = hx
21-
self._hy = hy
22-
self._nx, self._ny = z.shape
17+
self.__z = z
18+
self.__hx = hx
19+
self.__hy = hy
20+
self.__nx, self._ny = z.shape
2321

24-
def _get_coordinates(self):
22+
def __get_coordinates(self, x, y):
2523
"""
2624
Determines the x and y coordinates of the bottom-left corner of the grid cell
2725
2826
:return: A tuple containing the coordinates and its corresponding indices
2927
"""
3028

3129
# gets the grid steps to x
32-
i_minus_star = int(np.floor(self._x / self._hx))
33-
if i_minus_star > self._nx - 1:
30+
i_minus_star = int(np.floor(x / self.__hx))
31+
if i_minus_star > self.__nx - 1:
3432
raise Exception("x is out of bounds")
3533

3634
# final i index for interpolation
37-
i_minus = i_minus_star if i_minus_star < self._nx - 1 else self._nx - 1
35+
i_minus = i_minus_star if i_minus_star < self.__nx - 1 else self.__nx - 1
3836

3937
# gets the grid steps to y
40-
j_minus_star = int(np.floor(self._y / self._hy))
38+
j_minus_star = int(np.floor(y / self.__hy))
4139
if j_minus_star > self._ny - 1:
4240
raise Exception("y is out of bounds")
4341

4442
# final j index for interpolation
4543
j_minus = j_minus_star if j_minus_star < self._ny - 1 else self._ny - 1
4644

4745
# computes the coordinates at the computed indices
48-
x_minus = i_minus * self._hx
49-
y_minus = j_minus * self._hy
46+
x_minus = i_minus * self.__hx
47+
y_minus = j_minus * self.__hy
5048

5149
return x_minus, y_minus, i_minus, j_minus
5250

53-
def _euclidean_distances(self, x_minus, y_minus):
51+
def __euclidean_distances(self, x_minus, y_minus, x, y):
5452
"""
5553
Calculates Euclidean distances between (x,y) and the surrounding grid points in the unit cell
5654
@@ -60,15 +58,15 @@ def _euclidean_distances(self, x_minus, y_minus):
6058
[bottom left, top left, bottom right, top right]
6159
"""
6260

63-
bottom_left = np.sqrt((x_minus - self._x) ** 2 + (y_minus - self._y) ** 2)
64-
top_left = np.sqrt((x_minus - self._x) ** 2 + (y_minus + self._hy - self._y) ** 2)
65-
bottom_right = np.sqrt((x_minus + self._hx - self._x) ** 2 + (y_minus - self._y) ** 2)
66-
top_right = np.sqrt((x_minus + self._hx - self._x) ** 2 + (y_minus + self._hy - self._y) ** 2)
61+
bottom_left = np.sqrt((x_minus - x) ** 2 + (y_minus - y) ** 2)
62+
top_left = np.sqrt((x_minus - x) ** 2 + (y_minus + self.__hy - y) ** 2)
63+
bottom_right = np.sqrt((x_minus + self.__hx - x) ** 2 + (y_minus - y) ** 2)
64+
top_right = np.sqrt((x_minus + self.__hx - x) ** 2 + (y_minus + self.__hy - y) ** 2)
6765

6866
return bottom_left, top_left, bottom_right, top_right
6967

7068
@staticmethod
71-
def _rbf(d, gamma):
69+
def __rbf(d, gamma):
7270
"""
7371
Computes the Radial Basis Function (RBF) for a given distance and gamma
7472
@@ -78,27 +76,30 @@ def _rbf(d, gamma):
7876
"""
7977
return np.exp(-gamma * d ** 2)
8078

81-
def rbf_interpolate(self):
79+
def interpolate(self, x, y):
8280
"""
8381
Performs the Radial Basis function (RBF) interpolation for the point (x,y)
8482
8583
:return: the interpolated value at (x,y)
8684
"""
8785

88-
x_minus, y_minus, i_minus, j_minus = self._get_coordinates()
86+
x_minus, y_minus, i_minus, j_minus = self.__get_coordinates(x, y)
8987

90-
distances = self._euclidean_distances(x_minus, y_minus)
88+
distances = self.__euclidean_distances(x_minus, y_minus, x, y)
9189

92-
h_diag_squared = self._hx ** 2 + self._hy ** 2
90+
h_diag_squared = self.__hx ** 2 + self.__hy ** 2
9391
gamma = -np.log(0.005) / h_diag_squared
9492

95-
rbf_weights = [self._rbf(d, gamma) for d in distances]
93+
rbf_weights = [self.__rbf(d, gamma) for d in distances]
9694

9795
sum_rbf = np.sum(rbf_weights)
98-
interpolated = rbf_weights[0] * self._z[i_minus, j_minus]
99-
interpolated += rbf_weights[1] * self._z[i_minus, j_minus + 1]
100-
interpolated += rbf_weights[2] * self._z[i_minus + 1, j_minus]
101-
interpolated += rbf_weights[3] * self._z[i_minus + 1, j_minus + 1]
96+
interpolated = rbf_weights[0] * self.__z[i_minus, j_minus]
97+
interpolated += rbf_weights[1] * self.__z[i_minus, j_minus + 1]
98+
interpolated += rbf_weights[2] * self.__z[i_minus + 1, j_minus]
99+
interpolated += rbf_weights[3] * self.__z[i_minus + 1, j_minus + 1]
102100
interpolated /= sum_rbf
103101

104102
return interpolated
103+
104+
def interpolate_all(self):
105+
pass

0 commit comments

Comments
 (0)