Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

working conda env #3

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: slfrank
channels:
- frankong
- main
- conda-forge
- defaults
dependencies:
- cvxpy==1.2.1
- matplotlib==3.5.3
- numpy==1.19.5
- python==3.8.16
- scipy==1.9.1
- sigpy==0.1.23
- pandas==1.4.4
Empty file added slr_creation/__init__.py
Empty file.
93 changes: 93 additions & 0 deletions slr_creation/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
script for using slfrank to create slr pulse with set parameters and export with rfpf to interface
eg. in jstmc or emc
_________________
Jochen Schmidt
15.02.2023
"""

import numpy as np
import matplotlib.pyplot as plt
import slfrank
import pathlib as plib
from rf_pulse_files import rfpf
from slr_creation import options

import logging


def main(opts: options.Config):
opts.display()
gamma_Hz = 42577478.518
bandwidth = gamma_Hz * opts.desiredSliceGrad * 1e-3 * opts.desiredSliceThickness * 1e-3
if opts.desiredDuration > 0:
opts.timeBandwidth = bandwidth * opts.desiredDuration * 1e-6
else:
opts.desiredDuration = int(opts.timeBandwidth / bandwidth * 1e6)

opts.display()
# init object
slr_pulse = rfpf.RF(
name="slfrank_lin_phase_refocus",
duration_in_us=opts.desiredDuration,
bandwidth_in_Hz=bandwidth,
time_bandwidth=opts.timeBandwidth,
num_samples=opts.numSamples
)

# set solver
solver = 'PDHG'
pulse_type = opts.pulseType
phase_type = opts.phaseType

logging.info(f"Generating pulse"
f"\t\t__type: {pulse_type} \t __phase type {phase_type}")

# getting length n complex array
pulse_slfrank = slfrank.design_rf(
n=opts.numSamples, tb=opts.timeBandwidth,
ptype=pulse_type, phase=phase_type,
d1=opts.rippleSizes, d2=opts.rippleSizes,
solver=solver, max_iter=opts.maxIter)

slr_pulse.amplitude = np.real(pulse_slfrank)
slr_pulse.phase = np.angle(pulse_slfrank)

logging.info(f'SLfRank:\tEnergy={np.sum(np.abs(pulse_slfrank)**2)}\tPeak={np.abs(pulse_slfrank).max()}')

logging.info("plotting")
fig = slfrank.plot_slr_pulses(
np.full_like(pulse_slfrank, np.nan, dtype=complex),
pulse_slfrank, ptype=pulse_type, phase=phase_type,
omega_range=[-1, 1], tb=opts.timeBandwidth, d1=opts.rippleSizes, d2=opts.rippleSizes)
plt.tight_layout()
plt.show()

out_path = plib.Path(opts.outputPath).absolute()
plot_file = out_path.joinpath(f'{pulse_type}_{phase_type}.png')
pulse_file = out_path.joinpath(f"slfrank_{pulse_type}_pulse_{phase_type}_phase.pkl")

logging.info(f"saving plot {plot_file}")
fig.savefig(plot_file, bbox_inches="tight", transparent=True)

logging.info(f"saving pulse {pulse_file}")
slr_pulse.save(pulse_file)

logging.info(f"finished")


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s %(levelname)s :: %(name)s -- %(message)s',
datefmt='%I:%M:%S', level=logging.INFO)

# get cmd line input
parser, args = options.createCommandlineParser()

logging.info("set parameters")
conf_opts = options.Config.from_cmd_args(args)

try:
main(conf_opts)
except Exception as e:
logging.error(e)
parser.print_usage()
1 change: 1 addition & 0 deletions slr_creation/default_conf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"configFile": "", "outputPath": "./pulses/", "numSamples": 300, "timeBandwidth": 2.75, "desiredDuration": 0, "desiredSliceGrad": 35.0, "desiredSliceThickness": 0.7, "rippleSizes": 0.01, "pulseType": "ex", "phaseType": "linear", "maxIter": 2500}
65 changes: 65 additions & 0 deletions slr_creation/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import simple_parsing as sp
import dataclasses as dc
import pathlib as plib
import logging

logModule = logging.getLogger(__name__)


@dc.dataclass
class Config(sp.helpers.Serializable):
configFile: str = sp.field(default="", alias=["-c"])
outputPath: str = sp.field(default="./pulses/", alias=["-o"])

numSamples: int = sp.field(default=300, alias=["-n"])
timeBandwidth: float = sp.field(default=2.75, alias=["-tb"])
desiredDuration: int = sp.field(default=0, alias=["-dur"]) # us
desiredSliceGrad: float = sp.field(default=35.0, alias=["-sg"]) # [mT/m]
desiredSliceThickness: float = sp.field(default=0.7, alias=["-st"]) # [mm]
rippleSizes: float = sp.field(default=0.01, alias=["-r"])
pulseType: str = sp.choice("ex", "se", "inv", default="ex", alias=["-pu"])
phaseType: str = sp.choice("linear", "minimum", default="linear", alias=["-ph"])
maxIter: int = sp.field(2500, alias=["-mi"])

@classmethod
def from_cmd_args(cls, args: sp.ArgumentParser.parse_args):
# create default_dict
default_instance = cls()
instance = cls()
if args.config.configFile:
confPath = plib.Path(args.config.configFile).absolute()
instance = cls.load(confPath)
# might contain defaults
for key, item in default_instance.__dict__.items():
parsed_arg = args.config.__dict__.get(key)
# if parsed arguments are not defaults
if parsed_arg != default_instance.__dict__.get(key):
# update instance even if changed by config file -> that way prioritize cmd line input
instance.__setattr__(key, parsed_arg)
return instance

def display(self):
logModule.info(f"Parameters:\n")
for key, value in self.__dict__.items():
logModule.info(f"\t\t{key}: \t\t {value}")


def createCommandlineParser():
"""
Build the parser for arguments
Parse the input arguments.
"""
parser = sp.ArgumentParser(prog='slr_slfrank')
parser.add_arguments(Config, dest="config")
args = parser.parse_args()

return parser, args


if __name__ == '__main__':
save_path = plib.Path("default_conf.json").absolute()
save_path.parent.mkdir(parents=True, exist_ok=True)
conf = Config()
conf.display()
conf.save(save_path)