Skip to content

Commit

Permalink
made rho_r_array a class variable, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wrguenthner committed Apr 18, 2024
1 parent cd6f7d8 commit 8f98263
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 56 deletions.
17 changes: 14 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,26 @@ permissions:

jobs:
build:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"

runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.10"
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
56 changes: 24 additions & 32 deletions src/pythermo/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def poly_interp(self, xa, ya, jl, x=0):
return y

class zircon(crystal):
def __init__(self, radius, log2_nodes, relevant_tT, U_ppm, Th_ppm, Sm_ppm=0):
def __init__(self, radius, log2_nodes, relevant_tT, rho_r_array, U_ppm, Th_ppm, Sm_ppm=0):
"""
Constructor for zircon class, a sub class of the crystal super class.
Expand All @@ -533,6 +533,9 @@ def __init__(self, radius, log2_nodes, relevant_tT, U_ppm, Th_ppm, Sm_ppm=0):
relevant_tT: 2D array of floats
Discretized time-temperature path
rho_r_array: 2D array of floats
Matrix of dimensionless track density for the time-temperature path relevant_tT. Assumes that the # of rows in rho_r_array and relevant_tT are equivalent
U_ppm: float
Concentration of uranium in zircon (in ppm)
Expand All @@ -554,6 +557,7 @@ def __init__(self, radius, log2_nodes, relevant_tT, U_ppm, Th_ppm, Sm_ppm=0):
self.__Th_ppm = Th_ppm
self.__Sm_ppm = Sm_ppm
self.__relevant_tT = relevant_tT
self.__rho_r_array = rho_r_array

def get_radius(self):
return self.__radius
Expand All @@ -578,17 +582,14 @@ def get_Sm_ppm(self):

def get_relevant_tT(self):
return self.__relevant_tT

def get_rho_r_array(self):
return self.__rho_r_array

def guenthner_damage(self,rho_r_array):
def guenthner_damage(self):
"""
Calculates the amount of radiation damage at each time step of class variable relevant_tT using the parameterization of Guenthner et al. (2013) (https://doi.org/10.2475/03.2013.01). The damage amounts can then be directly used to calculate diffusivities at each time step.
Parameters
----------
rho_r_array: 2D array of floats
Matrix of dimensionless track density for the time-temperature path relevant_tT. Assumes that the # of rows in rho_r_array and relevant_tT are equivalent
Returns
-------
Expand All @@ -600,6 +601,7 @@ def guenthner_damage(self,rho_r_array):
U235_atom = self.__U_ppm * U235_ppm_atom
Th_atom = self.__Th_ppm * Th_ppm_atom
relevant_tT = self.__relevant_tT
rho_r_array = self.__rho_r_array

#calculate dose in alpha/g at each time step
alpha_i = [8*U238_atom*(np.exp(lambda_238*relevant_tT[i,0])-np.exp(lambda_238*relevant_tT[i+1,0])) + 7*U235_atom*(np.exp(lambda_235*relevant_tT[i,0])-np.exp(lambda_235*relevant_tT[i+1,0])) + 6*Th_atom*(np.exp(lambda_232*relevant_tT[i,0])-np.exp(lambda_232*relevant_tT[i+1,0])) for i in range(np.size(relevant_tT,0)-2,-1,-1)]
Expand All @@ -612,16 +614,10 @@ def guenthner_damage(self,rho_r_array):

return damage

def guenthner_date(self,rho_r_array):
def guenthner_date(self):
"""
Zircon (U-Th)/He date calculator. First, calculates the diffusivity at each time step of class variable relevant_tT using the parameterization of Guenthner et al. (2013) (https://doi.org/10.2475/03.2013.01). The diffusivities are then passed to the parent class method CN_diffusion, along with relevant parameters. Finally, the parent class method He_date is called to convert the He profile to a (U-Th)/He date.
Parameters
----------
rho_r_array: 2D array of floats
Matrix of dimensionless track density for the time-temperature path relevant_tT. Assumes that the # of rows in rho_r_array and relevant_tT are equivalent
Returns
-------
Expand All @@ -635,7 +631,7 @@ def guenthner_date(self,rho_r_array):
relevant_tT = self.__relevant_tT

#calculate damage levels using guenthner_damage function
damage = self.guenthner_damage(rho_r_array)
damage = self.guenthner_damage()

#Guenthner et al. (2013) diffusion equation parameters, Eas are in kJ/mol, D0s converted to microns2/s
Ea_l = 165.0
Expand Down Expand Up @@ -695,7 +691,7 @@ def guenthner_date(self,rho_r_array):
return date

class apatite(crystal):
def __init__(self, radius, log2_nodes, relevant_tT, U_ppm, Th_ppm, Sm_ppm=0):
def __init__(self, radius, log2_nodes, relevant_tT, rho_r_array, U_ppm, Th_ppm, Sm_ppm=0):
"""
Constructor for apatite class, a sub class of the crystal super class.
Expand All @@ -710,6 +706,9 @@ def __init__(self, radius, log2_nodes, relevant_tT, U_ppm, Th_ppm, Sm_ppm=0):
relevant_tT: 2D array of floats
Discretized time-temperature path
rho_r_array: 2D array of floats
Matrix of dimensionless track density for the time-temperature path relevant_tT. Assumes that the # of rows in rho_r_array and relevant_tT are equivalent
U_ppm: float
Concentration of uranium in apatite (in ppm)
Expand All @@ -731,6 +730,7 @@ def __init__(self, radius, log2_nodes, relevant_tT, U_ppm, Th_ppm, Sm_ppm=0):
self.__Th_ppm = Th_ppm
self.__Sm_ppm = Sm_ppm
self.__relevant_tT = relevant_tT
self.__rho_r_array = rho_r_array

def get_radius(self):
return self.__radius
Expand All @@ -755,16 +755,13 @@ def get_Sm_ppm(self):

def get_relevant_tT(self):
return self.__relevant_tT

def get_rho_r_array(self):
return self.__rho_r_array

def flowers_damage(self,rho_r_array):
def flowers_damage(self):
"""
Calculates the amount of radiation damage at each time step of class variable relevant_tT using the parameterization of Flowers et al. (2009) (https://doi.org/10.1016/j.gca.2009.01.015). The damage amounts can then be directly used to calculate diffusivities at each time step.
Parameters
----------
rho_r_array: 2D array of floats
Matrix of dimensionless track density for the time-temperature path relevant_tT. Assumes that the # of rows in rho_r_array and relevant_tT are equivalent
Returns
-------
Expand All @@ -784,6 +781,7 @@ def flowers_damage(self,rho_r_array):
Sm_vol = self.__Sm_ppm * Sm_ppm_atom * ap_density

relevant_tT = self.__relevant_tT
rho_r_array = self.__rho_r_array

#rho v calculation, atoms/cc
rho_v = [U238_vol*(np.exp(lambda_238*relevant_tT[i,0])-np.exp(lambda_238*relevant_tT[i+1,0])) + (7/8)*U235_vol*(np.exp(lambda_235*relevant_tT[i,0])-np.exp(lambda_235*relevant_tT[i+1,0])) + (6/8)*Th_vol*(np.exp(lambda_232*relevant_tT[i,0])-np.exp(lambda_232*relevant_tT[i+1,0])) + (1/8)*Sm_vol*(np.exp(lambda_147*relevant_tT[i,0])-np.exp(lambda_147*relevant_tT[i+1,0])) for i in range(np.size(relevant_tT,0)-2,-1,-1)]
Expand All @@ -796,15 +794,9 @@ def flowers_damage(self,rho_r_array):

return damage

def flowers_date(self,rho_r_array):
def flowers_date(self):
"""
Apatite (U-Th)/He date calculator. First, calculates the diffusivity at each time step of class variable relevant_tT using the parameterization of Flowers et al. (2009) (https://doi.org/10.1016/j.gca.2009.01.015). The diffusivities are then passed to the parent class method CN_diffusion, along with relevant parameters. Finally,the parent class method He_date is called to convert the He profile to a (U-Th)/He date.
Parameters
----------
rho_r_array: 2D array of floats
Matrix of dimensionless track density for the time-temperature path relevant_tT. Assumes that the # of rows in rho_r_array and relevant_tT are equivalent
Returns
-------
Expand All @@ -814,7 +806,7 @@ def flowers_date(self,rho_r_array):
"""
#calculate damage levels using flowers_damage function
damage = self.flowers_damage(rho_r_array)
damage = self.flowers_damage()

#Flowers et al. 2009 damage-diffusivity equation parameters
omega = 10**-22
Expand Down
28 changes: 14 additions & 14 deletions src/pythermo/tT_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def forward(self, comp_type='size', model_num=1, std_grain=0, log2_nodes=8):

if grains.iloc[j,0] == 'apatite':
if max_size > mean_size:
mean_grain = apatite(mean_size,log2_nodes,ap_tT,U_ppm,Th_ppm,Sm_ppm)
std_plus_grain = apatite(max_size,log2_nodes,ap_tT,U_ppm,Th_ppm,Sm_ppm)
std_minus_grain = apatite(min_size,log2_nodes,ap_tT,U_ppm,Th_ppm,Sm_ppm)
mean_grain = apatite(mean_size,log2_nodes,ap_tT,ap_anneal,U_ppm,Th_ppm,Sm_ppm)
std_plus_grain = apatite(max_size,log2_nodes,ap_tT,ap_anneal,U_ppm,Th_ppm,Sm_ppm)
std_minus_grain = apatite(min_size,log2_nodes,ap_tT,ap_anneal,U_ppm,Th_ppm,Sm_ppm)

#calculate dates
mean_date = mean_grain.flowers_date(ap_anneal)
std_plus_date = std_plus_grain.flowers_date(ap_anneal)
std_minus_date = std_minus_grain.flowers_date(ap_anneal)
mean_date = mean_grain.flowers_date()
std_plus_date = std_plus_grain.flowers_date()
std_minus_date = std_minus_grain.flowers_date()

#add dates to array
model_data[j,i*3//2] = mean_date
Expand All @@ -151,10 +151,10 @@ def forward(self, comp_type='size', model_num=1, std_grain=0, log2_nodes=8):
model_data[j+2*num_grains,i*3//2+2] = min_size

else:
mean_grain = apatite(mean_size,log2_nodes,ap_tT,U_ppm,Th_ppm,Sm_ppm)
mean_grain = apatite(mean_size,log2_nodes,ap_tT,ap_anneal,U_ppm,Th_ppm,Sm_ppm)

#calculate date
mean_date = mean_grain.flowers_date(ap_anneal)
mean_date = mean_grain.flowers_date()

#add date to array
model_data[j,i*3//2] = mean_date
Expand All @@ -167,14 +167,14 @@ def forward(self, comp_type='size', model_num=1, std_grain=0, log2_nodes=8):

elif grains.iloc[j,0] == 'zircon':
if max_size > mean_size:
mean_grain = zircon(mean_size,log2_nodes,zirc_tT,U_ppm,Th_ppm,Sm_ppm)
mean_grain = zircon(mean_size,log2_nodes,zirc_tT,zirc_anneal,U_ppm,Th_ppm,Sm_ppm)
std_plus_grain = zircon(max_size,log2_nodes,zirc_tT,U_ppm,Th_ppm,Sm_ppm)
std_minus_grain = zircon(min_size,log2_nodes,zirc_tT,U_ppm,Th_ppm,Sm_ppm)

#calculate dates
mean_date = mean_grain.guenthner_date(zirc_anneal)
std_plus_date = std_plus_grain.guenthner_date(zirc_anneal)
std_minus_date = std_minus_grain.guenthner_date(zirc_anneal)
mean_date = mean_grain.guenthner_date()
std_plus_date = std_plus_grain.guenthner_date()
std_minus_date = std_minus_grain.guenthner_date()

#add dates to array
model_data[j,i*3//2] = mean_date
Expand All @@ -192,10 +192,10 @@ def forward(self, comp_type='size', model_num=1, std_grain=0, log2_nodes=8):
model_data[j+2*num_grains,i*3//2+2] = min_size

else:
mean_grain = zircon(mean_size,log2_nodes,zirc_tT,U_ppm,Th_ppm,Sm_ppm)
mean_grain = zircon(mean_size,log2_nodes,zirc_tT,zirc_anneal,U_ppm,Th_ppm,Sm_ppm)

#calculate date
mean_date = mean_grain.guenthner_date(zirc_anneal)
mean_date = mean_grain.guenthner_date()

#add date to array
model_data[j,i*3//2] = mean_date
Expand Down
26 changes: 19 additions & 7 deletions tests/test_crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ def ap():
tT = pyt.tT_path(tT_in)
tT.tT_interpolate()
ap_anneal,ap_tT = tT.ketcham_anneal()
return pyt.apatite(60,8,ap_tT,100,50,Sm_ppm=10)
return pyt.apatite(60,8,ap_tT,ap_anneal,100,50,Sm_ppm=10)

@pytest.fixture()
def zirc():
tT_in = np.array([[0,20],[250,100],[500,20]])
tT_in = np.array([[0,20],[250,200],[500,20]])
tT = pyt.tT_path(tT_in)
tT.tT_interpolate()
zirc_anneal,zirc_tT = tT.ketcham_anneal()
return pyt.zircon(60,8,zirc_tT,1000,500)
return pyt.zircon(60,8,zirc_tT,zirc_anneal,1000,500)

def test_alpha_ejection():
pass
Expand All @@ -28,13 +28,25 @@ def test_He_date():
pass

def test_guenthner_damage(zirc):
pass
damage = zirc.guenthner_damage()
relevant_tT = zirc.get_relevant_tT()
assert np.size(damage) == np.size(relevant_tT,0)
assert np.isnan(damage).any() == False
assert np.any(damage < 0) == False

def test_guenthner_date(zirc):
pass
zirc_date = zirc.guenthner_date()
assert zirc_date > 0
assert zirc_date < 500

def test_flowers_damage(ap):
pass
damage = ap.guenthner_damage()
relevant_tT = ap.get_relevant_tT()
assert np.size(damage) == np.size(relevant_tT,0)
assert np.isnan(damage).any() == False
assert np.any(damage < 0) == False

def test_flowers_date(ap):
pass
ap_date = ap.flowers_date()
assert ap_date > 0
assert ap_date < 500

0 comments on commit 8f98263

Please sign in to comment.