Skip to content

Commit

Permalink
MAINT: Clean up code in calculations.py module
Browse files Browse the repository at this point in the history
* Restructures all calculation functions so
the symbolic output is first in the code path

* Changes all `output_strings`/`key` mismatch
errors to use the same error message

* Various variable name changes to make
code more readable
  • Loading branch information
nawtrey committed Aug 16, 2024
1 parent 2a7cc47 commit 48b3e16
Showing 1 changed file with 125 additions and 163 deletions.
288 changes: 125 additions & 163 deletions kda/calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,10 @@ def calc_sigma(G, dirpar_edges, key="name", output_strings=True):
Returns
-------
sigma : float
Normalization factor for state probabilities.
sigma_str : str
Sum of rate products of all directional diagrams for input
diagram ``G``, in string form.
sigma : float or str
Sum of rate products of all directional diagrams for the input
diagram ``G`` as a float (``output_strings=False``) for a string
(``output_strings=True``).
Notes
-----
Expand Down Expand Up @@ -160,41 +159,33 @@ def calc_sigma(G, dirpar_edges, key="name", output_strings=True):
of all directional diagrams for the kinetic diagram.
"""
# Number of nodes/states
n_states = G.number_of_nodes()
n_dirpars = dirpar_edges.shape[0]
edge_value = G.edges[list(G.edges)[0]][key]
edge_is_str = isinstance(G.edges[list(G.edges)[0]][key], str)
if output_strings != edge_is_str:
msg = f"""Inputs `key={key}` and `output_strings={output_strings}`
do not match. If symbolic outputs are requested the input `key`
should retrieve edge data from `G` that corresponds to symbolic
variable names for all edges."""
raise TypeError(msg)

if not output_strings:
if isinstance(edge_value, str):
raise TypeError(
"To enter variable strings set parameter output_strings=True."
)
dirpar_rate_products = np.ones(n_dirpars, dtype=float)
n_dir_diagrams = dirpar_edges.shape[0]
if output_strings:
rate_products = np.empty(shape=(n_dir_diagrams,), dtype=object)
# iterate over the directional diagrams
for i, edge_list in enumerate(dirpar_edges):
# iterate over the edges in the given directional diagram i
for edge in edge_list:
# multiply the rate of each edge in edge_list
dirpar_rate_products[i] *= G.edges[edge][key]
sigma = math.fsum(dirpar_rate_products)
return sigma
elif output_strings:
if not isinstance(edge_value, str):
raise TypeError(
"To enter variable values set parameter output_strings=False."
)
dirpar_rate_products = np.empty(shape=(n_dirpars,), dtype=object)
rates = [G.edges[edge][key] for edge in edge_list]
rate_products[i] = "*".join(rates)
# sum all terms to get normalization factor
sigma = "+".join(rate_products)
else:
rate_products = np.ones(n_dir_diagrams, dtype=float)
# iterate over the directional diagrams
for i, edge_list in enumerate(dirpar_edges):
rate_product_vals = []
# iterate over the edges in the given directional diagram i
for edge in edge_list:
# append rate constant names from dir_par to list
rate_product_vals.append(G.edges[edge][key])
dirpar_rate_products[i] = "*".join(rate_product_vals)
# sum all terms to get normalization factor
sigma_str = "+".join(dirpar_rate_products)
return sigma_str
# multiply the rate of each edge in edge_list
rate_products[i] *= G.edges[edge][key]
sigma = math.fsum(rate_products)
return sigma


def calc_sigma_K(G, cycle, flux_diags, key="name", output_strings=True):
Expand Down Expand Up @@ -228,12 +219,10 @@ def calc_sigma_K(G, cycle, flux_diags, key="name", output_strings=True):
Returns
-------
sigma_K : float
Sum of rate products of directional flux diagram edges pointing to
input cycle.
sigma_K_str : str
sigma_K : float or str
Sum of rate products of directional flux diagram edges pointing to
input cycle in string form.
input cycle as a float (``output_strings=False``) or as a string
(``output_strings=True``).
Notes
-----
Expand All @@ -252,55 +241,45 @@ def calc_sigma_K(G, cycle, flux_diags, key="name", output_strings=True):
sum. For cycles with no flux diagrams, :math:`\Sigma_{k} = 1`.
"""
if isinstance(flux_diags, list) == False:
print(
"No flux diagrams detected for cycle {}. Sigma K value is 1.".format(cycle)
)
if not isinstance(flux_diags, list):
print(f"No flux diagrams detected for cycle {cycle}. Sigma K value is 1.")
return 1
edge_is_str = isinstance(G.edges[list(G.edges)[0]][key], str)
if output_strings != edge_is_str:
msg = f"""Inputs `key={key}` and `output_strings={output_strings}`
do not match. If symbolic outputs are requested the input `key`
should retrieve edge data from `G` that corresponds to symbolic
variable names for all edges."""
raise TypeError(msg)

# check that the input cycle is in the correct order
ordered_cycle = _get_ordered_cycle(G, cycle)
cycle_edges = diagrams._construct_cycle_edges(ordered_cycle)
if output_strings:
rate_products = []
for diagram in flux_diags:
diag = diagram.copy()
for edge in cycle_edges:
diag.remove_edge(edge[0], edge[1], edge[2])
diag.remove_edge(edge[1], edge[0], edge[2])
rates = []
for edge in diag.edges:
rates.append(G.edges[edge[0], edge[1], edge[2]][key])
rate_products.append("*".join(rates))
sigma_K = "+".join(rate_products)
else:
# check that the input cycle is in the correct order
ordered_cycle = _get_ordered_cycle(G, cycle)
cycle_edges = diagrams._construct_cycle_edges(ordered_cycle)
if output_strings == False:
if isinstance(
G.edges[cycle_edges[0][0], cycle_edges[0][1], cycle_edges[0][2]][key],
str,
):
raise TypeError(
"To enter variable strings set parameter output_strings=True."
)
rate_products = []
for diagram in flux_diags:
diag = diagram.copy()
for edge in cycle_edges:
diag.remove_edge(edge[0], edge[1], edge[2])
diag.remove_edge(edge[1], edge[0], edge[2])
vals = 1
for edge in diag.edges:
vals *= G.edges[edge[0], edge[1], edge[2]][key]
rate_products.append(vals)
sigma_K = math.fsum(rate_products)
return sigma_K
elif output_strings == True:
if not isinstance(
G.edges[cycle_edges[0][0], cycle_edges[0][1], cycle_edges[0][2]][key],
str,
):
raise TypeError(
"To enter variable values set parameter output_strings=False."
)
rate_products = []
for diagram in flux_diags:
diag = diagram.copy()
for edge in cycle_edges:
diag.remove_edge(edge[0], edge[1], edge[2])
diag.remove_edge(edge[1], edge[0], edge[2])
rates = []
for edge in diag.edges:
rates.append(G.edges[edge[0], edge[1], edge[2]][key])
rate_products.append("*".join(rates))
sigma_K_str = "+".join(rate_products)
return sigma_K_str
rate_products = []
for diagram in flux_diags:
diag = diagram.copy()
for edge in cycle_edges:
diag.remove_edge(edge[0], edge[1], edge[2])
diag.remove_edge(edge[1], edge[0], edge[2])
vals = 1
for edge in diag.edges:
vals *= G.edges[edge[0], edge[1], edge[2]][key]
rate_products.append(vals)
sigma_K = math.fsum(rate_products)
return sigma_K


def calc_pi_difference(G, cycle, order, key="name", output_strings=True):
Expand Down Expand Up @@ -336,12 +315,10 @@ def calc_pi_difference(G, cycle, order, key="name", output_strings=True):
Returns
-------
pi_diff : float
Difference of product of counter clockwise cycle rates and clockwise
cycle rates.
pi_diff_str : str
String of difference of product of counter clockwise cycle rates and
clockwise cycle rates.
pi_difference : float or str
Difference of the counter-clockwise and clockwise cycle
rate-products as a float (``output_strings=False``)
or a string (``output_strings=True``).
Notes
-----
Expand All @@ -362,38 +339,33 @@ def calc_pi_difference(G, cycle, order, key="name", output_strings=True):
the forward (i.e. positive) direction is counter-clockwise (CCW).
"""
edge_is_str = isinstance(G.edges[list(G.edges)[0]][key], str)
if output_strings != edge_is_str:
msg = f"""Inputs `key={key}` and `output_strings={output_strings}`
do not match. If symbolic outputs are requested the input `key`
should retrieve edge data from `G` that corresponds to symbolic
variable names for all edges."""
raise TypeError(msg)

# check that the input cycle is in the correct order
ordered_cycle = _get_ordered_cycle(G, cycle)
CCW_cycle = graph_utils.get_ccw_cycle(ordered_cycle, order)
cycle_edges = diagrams._construct_cycle_edges(CCW_cycle)
if output_strings == False:
if isinstance(
G.edges[cycle_edges[0][0], cycle_edges[0][1], cycle_edges[0][2]][key], str
):
raise TypeError(
"To enter variable strings set parameter output_strings=True."
)
ccw_rates = 1
cw_rates = 1
for edge in cycle_edges:
ccw_rates *= G.edges[edge[0], edge[1], edge[2]][key]
cw_rates *= G.edges[edge[1], edge[0], edge[2]][key]
pi_difference = ccw_rates - cw_rates
return pi_difference
elif output_strings == True:
if not isinstance(
G.edges[cycle_edges[0][0], cycle_edges[0][1], cycle_edges[0][2]][key], str
):
raise TypeError(
"To enter variable values set parameter output_strings=False."
)
if output_strings:
ccw_rates = []
cw_rates = []
for edge in cycle_edges:
ccw_rates.append(G.edges[edge[0], edge[1], edge[2]][key])
cw_rates.append(G.edges[edge[1], edge[0], edge[2]][key])
pi_difference = "-".join(["*".join(ccw_rates), "*".join(cw_rates)])
return pi_difference
else:
ccw_rates = 1
cw_rates = 1
for edge in cycle_edges:
ccw_rates *= G.edges[edge[0], edge[1], edge[2]][key]
cw_rates *= G.edges[edge[1], edge[0], edge[2]][key]
pi_difference = ccw_rates - cw_rates
return pi_difference


def calc_thermo_force(G, cycle, order, key="name", output_strings=True):
Expand Down Expand Up @@ -427,12 +399,12 @@ def calc_thermo_force(G, cycle, order, key="name", output_strings=True):
Returns
-------
thermo_force : float
The calculated thermodynamic force for the input cycle. This value is
unitless and should be multiplied by ``kT``.
parsed_thermo_force_str : ``SymPy`` expression
Symbolic thermodynamic force expression. Should be
multiplied by ``kT`` to get actual thermodynamic force.
thermo_force : float or ``SymPy`` expression
The thermodynamic force for the input ``cycle`` returned
as a float (``output_strings=False``) or a ``SymPy`` expression
(``output_strings=True`). The returned value is unitless and
should be multiplied by ``kT`` to calculate the actual
thermodynamic force.
Notes
-----
Expand All @@ -452,41 +424,36 @@ def calc_thermo_force(G, cycle, order, key="name", output_strings=True):
(i.e. :math:`\chi_{k} = 0`).
"""
edge_is_str = isinstance(G.edges[list(G.edges)[0]][key], str)
if output_strings != edge_is_str:
msg = f"""Inputs `key={key}` and `output_strings={output_strings}`
do not match. If symbolic outputs are requested the input `key`
should retrieve edge data from `G` that corresponds to symbolic
variable names for all edges."""
raise TypeError(msg)

# check that the input cycle is in the correct order
ordered_cycle = _get_ordered_cycle(G, cycle)
CCW_cycle = graph_utils.get_ccw_cycle(ordered_cycle, order)
cycle_edges = diagrams._construct_cycle_edges(CCW_cycle)
if output_strings == False:
if isinstance(
G.edges[cycle_edges[0][0], cycle_edges[0][1], cycle_edges[0][2]][key], str
):
raise TypeError(
"To enter variable strings set parameter output_strings=True."
)
ccw_rates = 1
cw_rates = 1
for edge in cycle_edges:
ccw_rates *= G.edges[edge[0], edge[1], edge[2]][key]
cw_rates *= G.edges[edge[1], edge[0], edge[2]][key]
thermo_force = np.log(ccw_rates / cw_rates)
return thermo_force
elif output_strings == True:
if not isinstance(
G.edges[cycle_edges[0][0], cycle_edges[0][1], cycle_edges[0][2]][key], str
):
raise TypeError(
"To enter variable values set parameter output_strings=False."
)
if output_strings:
ccw_rates = []
cw_rates = []
for edge in cycle_edges:
ccw_rates.append(G.edges[edge[0], edge[1], edge[2]][key])
cw_rates.append(G.edges[edge[1], edge[0], edge[2]][key])
thermo_force_str = (
thermo_force = (
"ln(" + "*".join(ccw_rates) + ") - ln(" + "*".join(cw_rates) + ")"
)
parsed_thermo_force_str = logcombine(parse_expr(thermo_force_str), force=True)
return parsed_thermo_force_str
thermo_force = logcombine(parse_expr(thermo_force), force=True)
else:
ccw_rates = 1
cw_rates = 1
for edge in cycle_edges:
ccw_rates *= G.edges[edge[0], edge[1], edge[2]][key]
cw_rates *= G.edges[edge[1], edge[0], edge[2]][key]
thermo_force = np.log(ccw_rates / cw_rates)
return thermo_force


def calc_state_probs(G, key="name", output_strings=True, dir_edges=None):
Expand Down Expand Up @@ -561,43 +528,38 @@ def calc_state_probs(G, key="name", output_strings=True, dir_edges=None):
# get the number of nodes/states
n_states = G.number_of_nodes()
# get the number of directional diagrams
n_dirpars = dir_edges.shape[0]
n_dir_diagrams = dir_edges.shape[0]
# get the number of partial diagrams
n_partials = int(n_dirpars / n_states)
n_partials = int(n_dir_diagrams / n_states)
if output_strings:
dirpar_rate_products = np.empty(shape=(n_dirpars,), dtype=object)
rate_products = np.empty(shape=(n_dir_diagrams,), dtype=object)
for i, edge_list in enumerate(dir_edges):
rate_product_vals = []
for edge in edge_list:
rate_product_vals.append(G.edges[edge][key])
dirpar_rate_products[i] = "*".join(rate_product_vals)

rates = [G.edges[edge][key] for edge in edge_list]
rate_products[i] = "*".join(rates)
rate_products = rate_products.reshape(n_states, n_partials)
state_mults = np.empty(shape=(n_states,), dtype=object)
dirpar_rate_products = dirpar_rate_products.reshape(n_states, n_partials)
for i, arr in enumerate(dirpar_rate_products):
for i, arr in enumerate(rate_products):
state_mults[i] = "+".join(arr)
state_probs = expressions.construct_sympy_prob_funcs(state_mult_funcs=state_mults)
else:
# retrieve the rate matrix from G
Kij = graph_utils.retrieve_rate_matrix(G)
# create array of ones for storing rate products
dirpar_rate_products = np.ones(n_dirpars, dtype=float)
rate_products = np.ones(n_dir_diagrams, dtype=float)
# iterate over the directional diagrams
for i, edge_list in enumerate(dir_edges):
# for each edge list, retrieve an array of the ith and jth indices,
# retrieve the values associated with each (i, j) pair, and
# calculate the product of those values
Ki = edge_list[:, 0]
Kj = edge_list[:, 1]
dirpar_rate_products[i] = np.prod(Kij[Ki, Kj])

state_mults = dirpar_rate_products.reshape(n_states, n_partials).sum(axis=1)
state_probs = state_mults / math.fsum(dirpar_rate_products)
rate_products[i] = np.prod(Kij[Ki, Kj])
state_mults = rate_products.reshape(n_states, n_partials).sum(axis=1)
state_probs = state_mults / math.fsum(rate_products)
if any(elem < 0 for elem in state_probs):
raise ValueError(
"Calculated negative state probabilities, overflow or underflow occurred."
)

msg = """Calculated negative state probabilities,
overflow or underflow occurred."""
raise ValueError(msg)
return state_probs


Expand Down

0 comments on commit 48b3e16

Please sign in to comment.