Skip to content

Commit 865e102

Browse files
committed
Adding test coverage + heat equation tests
1 parent 22141f9 commit 865e102

10 files changed

+76
-25
lines changed

.github/workflows/ci.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ jobs:
1414
with:
1515
python-version: '3.13'
1616
architecture: 'x64'
17-
- name: Run Script
17+
- name: Run python tests
1818
run: |
19-
bash ./ci/script.sh
19+
bash ./ci/run_python_tests.sh

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616

1717
# Other files
1818
*.log
19+
/.coverage

ci/script.sh ci/run_python_tests.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ source venv/bin/activate
1111

1212
pip install -r requirements.txt
1313

14-
pytest
14+
pytest --cov=pdesolvers pdesolvers/tests/

pdesolvers/pdes/heat_1d.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,20 @@ def __check_conditions(self):
3838

3939
if self.__left_boundary_temp is not None:
4040
err = np.abs(self.__left_boundary_temp(0) - self.__initial_temp(0))
41-
assert err < 1e-12
41+
assert err < 1e-12, f"Left boundary condition at t=0 does not match the initial condition."
4242

4343
if self.__right_boundary_temp is not None:
44-
err = np.abs(self.__right_boundary_temp(0) - self.__initial_temp(0))
45-
assert err < 1e-12
44+
err = np.abs(self.__right_boundary_temp(0) - self.__initial_temp(self.__length))
45+
assert err < 1e-12, f"Right boundary condition at t=0 does not match the initial condition."
4646

4747
@staticmethod
4848
def __validate_callable(func):
4949
if not callable(func):
5050
raise ValueError("Temperature conditions must be a callable function")
5151

52-
def generate_x_grid(self):
53-
return np.linspace(0, self.__length, self.__x_nodes)
54-
55-
def generate_t_grid(self):
56-
return np.linspace(0, self.__time, self.__t_nodes)
52+
@staticmethod
53+
def generate_grid(value, nodes):
54+
return np.linspace(0, value, nodes)
5755

5856
@property
5957
def length(self):

pdesolvers/solvers/black_scholes_solvers.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ def solve(self):
2424
if self.equation.t_nodes is None:
2525
dt = 0.9 * dt_max
2626
self.equation.t_nodes = int(self.equation.expiry/dt)
27-
dt = self.equation.expiry / self.equation.t_nodes # to ensure that the expiration time is integer time steps away
27+
dt = self.equation.expiry / self.equation.t_nodes
2828
else:
29-
# possible fix - set a check to see that user-defined value is within cfl condition
3029
dt = T[1] - T[0]
3130

3231
if dt > dt_max:

pdesolvers/solvers/heat_solvers.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,29 @@ def solve(self):
1616
:return: the solver instance with the computed temperature values
1717
"""
1818

19-
x = self.equation.generate_x_grid()
20-
dx = x[1] - x[0]
19+
x = self.equation.generate_grid(self.equation.length, self.equation.x_nodes)
20+
t = self.equation.generate_grid(self.equation.time, self.equation.t_nodes)
2121

22+
dx = x[1] - x[0]
2223
dt_max = 0.5 * (dx**2) / self.equation.k
23-
dt = 0.8 * dt_max
24-
time_step = int(self.equation.time/dt)
25-
self.equation.t_nodes = time_step
2624

27-
t = np.linspace(0, self.equation.time, self.equation.t_nodes)
25+
if self.equation.t_nodes is None:
26+
dt = 0.8 * dt_max
27+
self.equation.t_nodes = int(self.equation.time/dt)
28+
dt = self.equation.time / self.equation.t_nodes
29+
else:
30+
dt = t[1] - t[0]
2831

29-
u = np.zeros((time_step, self.equation.x_nodes))
32+
if dt > dt_max:
33+
raise ValueError("User-defined t nodes is too small and exceeds the CFL condition. Possible action: Increase number of t nodes for stability!")
34+
35+
u = np.zeros((self.equation.t_nodes, self.equation.x_nodes))
3036

3137
u[0, :] = self.equation.get_initial_temp(x)
3238
u[:, 0] = self.equation.get_left_boundary(t)
3339
u[:, -1] = self.equation.get_right_boundary(t)
3440

35-
for tau in range(0, time_step-1):
41+
for tau in range(0, self.equation.t_nodes-1):
3642
for i in range(1, self.equation.x_nodes - 1):
3743
u[tau+1,i] = u[tau, i] + (dt * self.equation.k * (u[tau, i-1] - 2 * u[tau, i] + u[tau, i+1]) / dx**2)
3844

@@ -49,8 +55,8 @@ def solve(self):
4955
:return: the solver instance with the computed temperature values
5056
"""
5157

52-
x = self.equation.generate_x_grid()
53-
t = self.equation.generate_t_grid()
58+
x = self.equation.generate_grid(self.equation.length, self.equation.x_nodes)
59+
t = self.equation.generate_grid(self.equation.time, self.equation.t_nodes)
5460

5561
dx = x[1] - x[0]
5662
dt = t[1] - t[0]

pdesolvers/tests/__init__.py

Whitespace-only changes.

pdesolvers/tests/test_black_scholes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_convergence_between_interpolated_data(self):
8686

8787
diff = np.abs(data1 - data2)
8888

89-
assert diff < 1e-4
89+
assert np.max(diff) < 1e-4
9090

9191

9292

pdesolvers/tests/test_heat_solvers.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
import numpy as np
3+
4+
import pdesolvers.pdes.heat_1d as heat
5+
import pdesolvers.solvers.heat_solvers as solver
6+
import pdesolvers.utils.utility as utility
7+
8+
class TestHeatSolvers:
9+
10+
def setup_method(self):
11+
self.equation = (heat.HeatEquation(1, 100,30,10000, 0.01)
12+
.set_initial_temp(lambda x: 10 * np.sin(2 * np.pi * x) + 15)
13+
.set_left_boundary_temp(lambda t: t + 15)
14+
.set_right_boundary_temp(lambda t: 15))
15+
16+
def test_check_terminal_and_boundary_conditions_at_time_zero(self):
17+
length = self.equation.length
18+
assert abs(self.equation.get_left_boundary(0) - self.equation.get_initial_temp(0)) < 1e-12, "Left boundary condition failed"
19+
assert abs(self.equation.get_right_boundary(0) - self.equation.get_initial_temp(length)) < 1e-12, "Right boundary condition failed"
20+
21+
# explicit method tests
22+
23+
def test_check_absolute_difference_between_two_results(self):
24+
result1 = solver.Heat1DExplicitSolver(self.equation).solve().get_result()
25+
result2 = solver.Heat1DCNSolver(self.equation).solve().get_result()
26+
27+
diff = result1 - result2
28+
29+
assert np.max(np.abs(diff)) < 1e-2
30+
31+
def test_convergence_between_interpolated_data(self):
32+
result1 = solver.Heat1DExplicitSolver(self.equation).solve()
33+
result2 = solver.Heat1DCNSolver(self.equation).solve()
34+
u1 = result1.get_result()
35+
u2 = result2.get_result()
36+
37+
data1 = utility.RBFInterpolator(u1, 0.1, 0.03).interpolate(0.2,0.9)
38+
data2 = utility.RBFInterpolator(u2, 0.1, 0.03).interpolate(0.2,0.9)
39+
40+
diff = np.abs(data1 - data2)
41+
42+
assert np.max(diff) < 1e-4
43+
44+
45+
46+
# crank-nicolson method tests

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ matplotlib==3.9.2
22
numpy==2.1.3
33
scipy==1.14.1
44
pandas==2.2.3
5-
pytest==8.3.4
5+
pytest==8.3.4
6+
pytest-cov==6.0.0

0 commit comments

Comments
 (0)