diff --git a/mushroom/core/crystutils.py b/mushroom/core/crystutils.py index 3e02069..7968160 100644 --- a/mushroom/core/crystutils.py +++ b/mushroom/core/crystutils.py @@ -330,28 +330,28 @@ def display_symmetry_info(latt, posi, atms, n_sym_cols: int = 4): atms_uniq = list(set(atms)) atms_spglib = [atms_uniq.index(x) for x in atms] ds = spglib.get_symmetry_dataset((latt, posi, atms_spglib)) - spg_number = ds["number"] + spg_number = ds.number if spg_number in SPGNUMBER2NAME: - info_spacegroup = "{} (#{}, {})".format(ds["international"], spg_number, SPGNUMBER2NAME[spg_number]) + info_spacegroup = "{} (#{}, {})".format(ds.international, spg_number, SPGNUMBER2NAME[spg_number]) else: - info_spacegroup = "{} (#{})".format(ds["international"], spg_number) + info_spacegroup = "{} (#{})".format(ds.international, spg_number) print("Space group:", info_spacegroup) - print("Point group:", ds["pointgroup"]) - print("Hall: {} (#{})".format(ds["hall"], ds["hall_number"])) + print("Point group:", ds.pointgroup) + print("Hall: {} (#{})".format(ds.hall, ds.hall_number)) print("Symmetry operations") fmtstr = "{:2s} {:2d} {:2d} {:2d} | {:.2f}" rowsep = "-" * 18 rowsep = (rowsep + " ") * (n_sym_cols - 1) + rowsep print(rowsep) - nsyms = len(ds["rotations"]) + nsyms = len(ds.rotations) nrows = nsyms // n_sym_cols + int(nsyms % n_sym_cols != 0) def p(*s): print(*s, sep=" ") for irow in range(nrows): - rots = ds["rotations"][irow * n_sym_cols:(irow + 1) * n_sym_cols] - trans = ds["translations"][irow * n_sym_cols:(irow + 1) * n_sym_cols] + rots = ds.rotations[irow * n_sym_cols:(irow + 1) * n_sym_cols] + trans = ds.translations[irow * n_sym_cols:(irow + 1) * n_sym_cols] p(*[fmtstr.format("", *rot[0, :], tran[0]) for rot, tran in zip(rots, trans)]) p(*[fmtstr.format(str(i + 1 + irow * n_sym_cols), *rot[1, :], tran[1]) diff --git a/mushroom/core/test/test_crystutils.py b/mushroom/core/test/test_crystutils.py index 44ad388..466a9c0 100644 --- a/mushroom/core/test/test_crystutils.py +++ b/mushroom/core/test/test_crystutils.py @@ -110,7 +110,6 @@ def test_get_density(self): self.assertAlmostEqual(density_m, 0.001008e27 / NAV) - class test_symmetry_related(ut.TestCase): def test_display_symmetry_info(self): @@ -127,5 +126,6 @@ def test_display_symmetry_info(self): atms = [11, 17] display_symmetry_info(latt, posi, atms) + if __name__ == "__main__": ut.main()