Skip to content

Commit e7f596f

Browse files
committed
FEATURE: adding enums
1 parent 71e94ca commit e7f596f

File tree

8 files changed

+31
-15
lines changed

8 files changed

+31
-15
lines changed

pdesolvers/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .pdes import *
22
from .solution import *
33
from .solvers import *
4-
from .utils import *
4+
from .utils import *
5+
from .enums import *

pdesolvers/enums/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .option_type import OptionType

pdesolvers/enums/option_type.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from enum import Enum
2+
3+
class OptionType(Enum):
4+
EUROPEAN_CALL = 1
5+
EUROPEAN_PUT = 2

pdesolvers/main.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22
import pdesolvers as pde
33

4-
54
def main():
65

76
# testing for heat equation
@@ -16,7 +15,7 @@ def main():
1615
# solver2 = pde.Heat1DExplicitSolver(equation1)
1716

1817
# testing for bse
19-
equation2 = pde.BlackScholesEquation('call', 300, 1, 0.2, 0.05, 100, 100, 20000)
18+
equation2 = pde.BlackScholesEquation(pde.OptionType.EUROPEAN_CALL, 300, 1, 0.2, 0.05, 100, 100, 20000)
2019

2120
solver1 = pde.BlackScholesCNSolver(equation2)
2221
# solver2 = pde.BlackScholesExplicitSolver(equation2)

pdesolvers/pdes/black_scholes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import numpy as np
2+
from pdesolvers.enums.option_type import OptionType
23

34
class BlackScholesEquation:
45

5-
def __init__(self, option_type, S_max, expiry, sigma, r, K, s_nodes=1, t_nodes=None):
6+
def __init__(self, option_type: OptionType, S_max, expiry, sigma, r, K, s_nodes=1, t_nodes=None):
67
"""
78
Initialises the solver with the necessary parameters
89
@@ -16,6 +17,8 @@ def __init__(self, option_type, S_max, expiry, sigma, r, K, s_nodes=1, t_nodes=N
1617
:param t_nodes: number of time nodes
1718
"""
1819

20+
if not isinstance(option_type, OptionType):
21+
raise TypeError(f"Option type must be of type OptionType enum" )
1922
self.__option_type = option_type
2023
self.__S_max = S_max
2124
self.__expiry = expiry

pdesolvers/solution/solution.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pdesolvers.pdes.black_scholes as bse
21
import numpy as np
32
from matplotlib import pyplot as plt
43

pdesolvers/solvers/black_scholes_solvers.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pdesolvers.solution as sol
55
import pdesolvers.pdes.black_scholes as bse
6+
import pdesolvers.enums.option_type as enum
67

78
class BlackScholesExplicitSolver:
89

@@ -36,9 +37,9 @@ def solve(self):
3637
V = np.zeros((self.equation.s_nodes + 1, self.equation.t_nodes + 1))
3738

3839
# setting terminal condition
39-
if self.equation.option_type == 'call':
40+
if self.equation.option_type == enum.OptionType.EUROPEAN_CALL:
4041
V[:,-1] = np.maximum((S - self.equation.strike_price), 0)
41-
elif self.equation.option_type == 'put':
42+
elif self.equation.option_type == enum.OptionType.EUROPEAN_PUT:
4243
V[:,-1] = np.maximum((self.equation.strike_price - S), 0)
4344
else:
4445
raise ValueError("Invalid option type - please choose between call/put")
@@ -74,10 +75,10 @@ def __set_boundary_conditions(self, T, tau):
7475

7576
lower_boundary = None
7677
upper_boundary = None
77-
if self.equation.option_type == 'call':
78+
if self.equation.option_type == enum.OptionType.EUROPEAN_CALL:
7879
lower_boundary = 0
7980
upper_boundary = self.equation.S_max - self.equation.strike_price * np.exp(-self.equation.rate * (self.equation.expiry - T[tau]))
80-
elif self.equation.option_type == 'put':
81+
elif self.equation.option_type == enum.OptionType.EUROPEAN_PUT:
8182
lower_boundary = self.equation.strike_price * np.exp(-self.equation.rate * (self.equation.expiry - T[tau]))
8283
upper_boundary = 0
8384

@@ -130,14 +131,14 @@ def solve(self):
130131

131132

132133
# setting terminal condition (for all values of S at time T)
133-
if self.equation.option_type == 'call':
134+
if self.equation.option_type == enum.OptionType.EUROPEAN_CALL:
134135
V[:,-1] = np.maximum((S - self.equation.strike_price), 0)
135136

136137
# setting boundary conditions (for all values of t at asset prices S=0 and S=Smax)
137138
V[0, :] = 0
138139
V[-1, :] = S[-1] - self.equation.strike_price * np.exp(-self.equation.rate * (self.equation.expiry - T))
139140

140-
elif self.equation.option_type == 'put':
141+
elif self.equation.option_type == enum.OptionType.EUROPEAN_PUT:
141142
V[:,-1] = np.maximum((self.equation.strike_price - S), 0)
142143
V[0, :] = self.equation.strike_price * np.exp(-self.equation.rate * (self.equation.expiry - T))
143144
V[-1, :] = 0

pdesolvers/tests/test_black_scholes.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
import pdesolvers.pdes.black_scholes as bse
55
import pdesolvers.solvers.black_scholes_solvers as solver
66
import pdesolvers.utils.utility as utility
7+
import pdesolvers.enums.option_type as enum
8+
9+
class TestBlackScholesEquation:
10+
11+
def test_check_invalid_option_type_input(self):
12+
with pytest.raises(TypeError, match="Option type must be of type OptionType enum"):
13+
self.equation = bse.BlackScholesEquation('woo', 300, 1, 0.2, 0.05, 100, 100, 2000)
714

815
class TestBlackScholesSolvers:
916

1017
def setup_method(self):
11-
self.equation = bse.BlackScholesEquation('call', 300, 1, 0.2, 0.05, 100, 100, 2000)
18+
self.equation = bse.BlackScholesEquation(enum.OptionType.EUROPEAN_CALL, 300, 1, 0.2, 0.05, 100, 100, 2000)
1219

1320
# explicit method tests
1421

@@ -26,7 +33,7 @@ def test_check_terminal_condition_for_call_explicit(self):
2633
assert np.array_equal(result[:, -1], expected_payoff)
2734

2835
def test_check_terminal_condition_for_put_explicit(self):
29-
self.equation.option_type = 'put'
36+
self.equation.option_type = enum.OptionType.EUROPEAN_PUT
3037
result = solver.BlackScholesExplicitSolver(self.equation).solve().get_result()
3138

3239
test_asset_grid = self.equation.generate_grid(self.equation.S_max, self.equation.s_nodes)
@@ -36,7 +43,7 @@ def test_check_terminal_condition_for_put_explicit(self):
3643
assert np.array_equal(result[:,-1], expected_payoff)
3744

3845
def test_check_valid_option_type(self):
39-
self.equation.option_type = 'woo'
46+
self.equation.option_type = "INVALID"
4047

4148
with pytest.raises(ValueError, match="Invalid option type - please choose between call/put"):
4249
solver.BlackScholesExplicitSolver(self.equation).solve().get_result()
@@ -57,7 +64,7 @@ def test_check_terminal_condition_for_call_cn(self):
5764
assert np.array_equal(result[:, -1], expected_payoff)
5865

5966
def test_check_terminal_condition_for_put_cn(self):
60-
self.equation.option_type = 'put'
67+
self.equation.option_type = enum.OptionType.EUROPEAN_PUT
6168
result = solver.BlackScholesCNSolver(self.equation).solve().get_result()
6269

6370
test_asset_grid = self.equation.generate_grid(self.equation.S_max, self.equation.s_nodes)

0 commit comments

Comments
 (0)