Skip to content

Commit 89fb01f

Browse files
authored
Merge pull request #38 from chemardes/feature/adding-rbf-interpolation
FEATURE: adding rbf interpolation
2 parents 0bede35 + e60507a commit 89fb01f

File tree

9 files changed

+187
-18
lines changed

9 files changed

+187
-18
lines changed

pdesolvers/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .pdes import *
22
from .solution import *
3-
from .solvers import *
3+
from .solvers import *
4+
from .utils import *

pdesolvers/main.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,18 @@ def main():
1919
equation2 = pde.BlackScholesEquation('call', 300, 1, 0.2, 0.05, 100, 100, 20000)
2020

2121
solver1 = pde.BlackScholesCNSolver(equation2)
22-
res1 = solver1.solve()
23-
# res2 = solver2.solve()
22+
solver2 = pde.BlackScholesExplicitSolver(equation2)
23+
res1 = solver1.solve().get_result()
24+
res2 = solver2.solve().get_result()
2425

25-
print(res1.get_result())
26-
res1.plot()
26+
interpolator1 = pde.RBFInterpolator(res1, 0.8, 200,0.1, 0.03)
27+
interpolator2 = pde.RBFInterpolator(res2, 0.8, 200,0.1, 0.03)
28+
print(interpolator1.rbf_interpolate())
29+
print(interpolator2.rbf_interpolate())
30+
31+
32+
# print(res.shape)
33+
# res1.plot()
2734

2835
if __name__ == "__main__":
2936
main()

pdesolvers/solution/solution.py

+20
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ def __init__(self, result, x_grid, t_grid):
1010
self.t_grid = t_grid
1111

1212
def plot(self):
13+
"""
14+
Generates a 3D surface plot of the temperature distribution across a grid of space and time
15+
16+
:return: 3D surface plot
17+
"""
1318

1419
if self.result is None:
1520
raise RuntimeError("Solution has not been computed - please run the solver.")
@@ -31,6 +36,11 @@ def plot(self):
3136
plt.show()
3237

3338
def get_result(self):
39+
"""
40+
Gets the grid of computed temperature values
41+
42+
:return: grid result
43+
"""
3444
return self.result
3545

3646
class SolutionBlackScholes:
@@ -40,6 +50,11 @@ def __init__(self, result, s_grid, t_grid):
4050
self.t_grid = t_grid
4151

4252
def plot(self):
53+
"""
54+
Generates a 3D surface plot of the option values across a grid of asset prices and time
55+
56+
:return: 3D surface plot
57+
"""
4358

4459
X, Y = np.meshgrid(self.t_grid, self.s_grid)
4560

@@ -57,4 +72,9 @@ def plot(self):
5772
plt.show()
5873

5974
def get_result(self):
75+
"""
76+
Gets the grid of computed option prices
77+
78+
:return: grid result
79+
"""
6080
return self.result

pdesolvers/solvers/black_scholes_solvers.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@ def __init__(self, equation: bse.BlackScholesEquation):
1212
def solve(self):
1313
"""
1414
This method solves the Black-Scholes equation using the explicit finite difference method
15+
1516
:return: the solver instance with the computed option values
1617
"""
1718

19+
S = self.equation.generate_asset_grid()
20+
T = self.equation.generate_time_grid()
21+
1822
dt_max = 1/((self.equation.s_nodes**2) * (self.equation.sigma**2)) # cfl condition to ensure stability
1923

2024
if self.equation.t_nodes is None:
@@ -23,15 +27,11 @@ def solve(self):
2327
dt = self.equation.expiry / self.equation.t_nodes # to ensure that the expiration time is integer time steps away
2428
else:
2529
# possible fix - set a check to see that user-defined value is within cfl condition
26-
dt = self.equation.expiry / self.equation.t_nodes
30+
dt = T[1] - T[0]
2731

2832
if dt > dt_max:
2933
raise ValueError("User-defined t nodes is too small and exceeds the CFL condition. Possible action: Increase number of t nodes for stability!")
3034

31-
32-
S = self.equation.generate_asset_grid()
33-
T = self.equation.generate_time_grid()
34-
3535
dS = S[1] - S[0]
3636

3737
V = np.zeros((self.equation.s_nodes + 1, self.equation.t_nodes + 1))
@@ -47,7 +47,7 @@ def solve(self):
4747
for tau in reversed(range(self.equation.t_nodes)):
4848
for i in range(1, self.equation.s_nodes):
4949
delta = (V[i+1, tau+1] - V[i-1, tau+1]) / (2 * dS)
50-
gamma = (V[i+1, tau+1] - 2 * V[i,tau+1] + V[i-1, tau+1]) / dS ** 2
50+
gamma = (V[i+1, tau+1] - 2 * V[i,tau+1] + V[i-1, tau+1]) / (dS ** 2)
5151
theta = -0.5 * ( self.equation.sigma ** 2) * (S[i] ** 2) * gamma - self.equation.rate * S[i] * delta + self.equation.rate * V[i, tau+1]
5252
V[i, tau] = V[i, tau + 1] - (theta * dt)
5353

@@ -59,6 +59,14 @@ def solve(self):
5959
return sol.SolutionBlackScholes(V,S,T)
6060

6161
def __set_boundary_conditions(self, T, tau):
62+
"""
63+
Sets the boundary conditions for the Black-Scholes Equation based on option type
64+
65+
:param T: grid of time steps
66+
:param tau: index of current time step
67+
:return: a tuple representing the boundary values for the given time step
68+
"""
69+
6270
lower_boundary = None
6371
upper_boundary = None
6472
if self.equation.option_type == 'call':
@@ -76,6 +84,11 @@ def __init__(self, equation: bse.BlackScholesEquation):
7684
self.equation = equation
7785

7886
def solve(self):
87+
"""
88+
This method solves the Black-Scholes equation using the Crank-Nicolson method
89+
90+
:return: the solver instance with the computed option values
91+
"""
7992

8093
S = self.equation.generate_asset_grid()
8194
T = self.equation.generate_time_grid()

pdesolvers/solvers/heat_solvers.py

+10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ def __init__(self, equation: heat.HeatEquation):
1010
self.equation = equation
1111

1212
def solve(self):
13+
"""
14+
This method solves the heat (diffusion) equation using the explicit finite difference method
15+
16+
:return: the solver instance with the computed temperature values
17+
"""
1318

1419
x = self.equation.generate_x_grid()
1520
dx = x[1] - x[0]
@@ -38,6 +43,11 @@ def __init__(self, equation: heat.HeatEquation):
3843
self.equation = equation
3944

4045
def solve(self):
46+
"""
47+
This method solves the heat (diffusion) equation using the Crank Nicolson method
48+
49+
:return: the solver instance with the computed temperature values
50+
"""
4151

4252
x = self.equation.generate_x_grid()
4353
t = self.equation.generate_t_grid()

pdesolvers/tests/test_black_scholes.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
import pdesolvers.pdes.black_scholes as bse
55
import pdesolvers.solvers.black_scholes_solvers as solver
6+
import pdesolvers.utils.utility as utility
67

78
class TestBlackScholesSolvers:
89

910
def setup_method(self):
10-
self.equation = bse.BlackScholesEquation('call', 300, 1, 0.2, 0.05, 100, 500, 20000)
11+
self.equation = bse.BlackScholesEquation('call', 300, 1, 0.2, 0.05, 100, 100, 2000)
1112

1213
# explicit method tests
1314

@@ -72,10 +73,21 @@ def test_check_absolute_difference_between_two_results(self):
7273
u2 = result2.get_result()
7374
diff = u1 - u2
7475

75-
# X, Y = np.meshgrid(result1.t_grid, result1.s_grid)
76+
assert np.max(np.abs(diff)) < 1e-2
77+
78+
def test_convergence_between_interpolated_data(self):
79+
result1 = solver.BlackScholesExplicitSolver(self.equation).solve()
80+
result2 = solver.BlackScholesCNSolver(self.equation).solve()
81+
u1 = result1.get_result()
82+
u2 = result2.get_result()
83+
84+
data1 = utility.RBFInterpolator(u1, 0.1, 0.03).interpolate(4,20)
85+
data2 = utility.RBFInterpolator(u2,0.1, 0.03).interpolate(4,20)
86+
87+
diff = np.abs(data1 - data2)
88+
89+
assert diff < 1e-4
90+
91+
92+
7693

77-
# fig = plt.figure(figsize=(10,6))
78-
# ax = fig.add_subplot(111, projection='3d')
79-
# surf = ax.plot_surface(X, Y, diff, cmap='viridis')
80-
print(np.max(np.abs(diff)))
81-
# plt.show()

pdesolvers/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .utility import RBFInterpolator

pdesolvers/utils/utility.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import numpy as np
2+
3+
4+
class RBFInterpolator:
5+
6+
def __init__(self, z, hx, hy):
7+
"""
8+
Initializes the RBF Interpolator.
9+
10+
:param z: 2D array of values at the grid points.
11+
:param x: x-coordinate of the point to interpolate.
12+
:param y: y-coordinate of the point to interpolate.
13+
:param hx: Grid spacing in the x-direction.
14+
:param hy: Grid spacing in the y-direction.
15+
"""
16+
17+
self.__z = z
18+
self.__hx = hx
19+
self.__hy = hy
20+
self.__nx, self._ny = z.shape
21+
22+
def __get_coordinates(self, x, y):
23+
"""
24+
Determines the x and y coordinates of the bottom-left corner of the grid cell
25+
26+
:return: A tuple containing the coordinates and its corresponding indices
27+
"""
28+
29+
# gets the grid steps to x
30+
i_minus_star = int(np.floor(x / self.__hx))
31+
if i_minus_star > self.__nx - 1:
32+
raise Exception("x is out of bounds")
33+
34+
# final i index for interpolation
35+
i_minus = i_minus_star if i_minus_star < self.__nx - 1 else self.__nx - 1
36+
37+
# gets the grid steps to y
38+
j_minus_star = int(np.floor(y / self.__hy))
39+
if j_minus_star > self._ny - 1:
40+
raise Exception("y is out of bounds")
41+
42+
# final j index for interpolation
43+
j_minus = j_minus_star if j_minus_star < self._ny - 1 else self._ny - 1
44+
45+
# computes the coordinates at the computed indices
46+
x_minus = i_minus * self.__hx
47+
y_minus = j_minus * self.__hy
48+
49+
return x_minus, y_minus, i_minus, j_minus
50+
51+
def __euclidean_distances(self, x_minus, y_minus, x, y):
52+
"""
53+
Calculates Euclidean distances between (x,y) and the surrounding grid points in the unit cell
54+
55+
:param x_minus: x-coordinate of the bottom-left corner of the grid
56+
:param y_minus: y-coordinate of the bottom-left corner of the grid
57+
:return: returns tuple with the Euclidean distances to the surrounding grid points:
58+
[bottom left, top left, bottom right, top right]
59+
"""
60+
61+
bottom_left = np.sqrt((x_minus - x) ** 2 + (y_minus - y) ** 2)
62+
top_left = np.sqrt((x_minus - x) ** 2 + (y_minus + self.__hy - y) ** 2)
63+
bottom_right = np.sqrt((x_minus + self.__hx - x) ** 2 + (y_minus - y) ** 2)
64+
top_right = np.sqrt((x_minus + self.__hx - x) ** 2 + (y_minus + self.__hy - y) ** 2)
65+
66+
return bottom_left, top_left, bottom_right, top_right
67+
68+
@staticmethod
69+
def __rbf(d, gamma):
70+
"""
71+
Computes the Radial Basis Function (RBF) for a given distance and gamma
72+
73+
:param d: the Euclidean distance to a grid point
74+
:param gamma: gamma parameter
75+
:return: the RBF value for the distance d
76+
"""
77+
return np.exp(-gamma * d ** 2)
78+
79+
def interpolate(self, x, y):
80+
"""
81+
Performs the Radial Basis function (RBF) interpolation for the point (x,y)
82+
83+
:return: the interpolated value at (x,y)
84+
"""
85+
86+
x_minus, y_minus, i_minus, j_minus = self.__get_coordinates(x, y)
87+
88+
distances = self.__euclidean_distances(x_minus, y_minus, x, y)
89+
90+
h_diag_squared = self.__hx ** 2 + self.__hy ** 2
91+
gamma = -np.log(0.005) / h_diag_squared
92+
93+
rbf_weights = [self.__rbf(d, gamma) for d in distances]
94+
95+
sum_rbf = np.sum(rbf_weights)
96+
interpolated = rbf_weights[0] * self.__z[i_minus, j_minus]
97+
interpolated += rbf_weights[1] * self.__z[i_minus, j_minus + 1]
98+
interpolated += rbf_weights[2] * self.__z[i_minus + 1, j_minus]
99+
interpolated += rbf_weights[3] * self.__z[i_minus + 1, j_minus + 1]
100+
interpolated /= sum_rbf
101+
102+
return interpolated
103+
104+
def interpolate_all(self):
105+
pass
File renamed without changes.

0 commit comments

Comments
 (0)