From c9b9dc3c400408a6f145475afa121889c46955c5 Mon Sep 17 00:00:00 2001 From: larsevj Date: Sun, 25 Aug 2024 15:00:40 +0200 Subject: [PATCH] Use ruff as linter and formatter --- .github/workflows/style.yml | 13 +- cmake/create_cmakelists.py | 2 +- pyproject.toml | 41 +++++ python/docs/examples/avg_pressure.py | 13 +- python/docs/examples/grid_info.py | 5 +- python/resdata/__init__.py | 8 +- python/resdata/geometry/cpolyline.py | 12 +- .../resdata/geometry/cpolyline_collection.py | 9 +- python/resdata/geometry/geo_region.py | 4 +- python/resdata/geometry/geometry_tools.py | 37 ++--- python/resdata/geometry/polyline.py | 12 +- python/resdata/geometry/surface.py | 9 +- python/resdata/gravimetry/rd_subsidence.py | 10 +- python/resdata/grid/cell.py | 2 +- python/resdata/grid/faults/fault.py | 59 +++---- python/resdata/grid/faults/fault_block.py | 5 +- .../resdata/grid/faults/fault_block_layer.py | 4 +- python/resdata/grid/faults/fault_line.py | 14 +- python/resdata/grid/faults/fault_segments.py | 9 +- python/resdata/grid/faults/layer.py | 4 +- python/resdata/grid/rd_grid.py | 44 ++--- python/resdata/grid/rd_grid_generator.py | 2 +- python/resdata/grid/rd_region.py | 5 +- python/resdata/rd_util.py | 6 +- python/resdata/resfile/rd_3d_file.py | 9 +- python/resdata/resfile/rd_3dkw.py | 24 ++- python/resdata/resfile/rd_file.py | 18 +- python/resdata/resfile/rd_file_view.py | 10 +- python/resdata/resfile/rd_kw.py | 154 ++++++++---------- python/resdata/resfile/rd_restart_file.py | 6 +- python/resdata/rft/rd_rft.py | 6 +- python/resdata/rft/well_trajectory.py | 8 +- python/resdata/summary/rd_cmp.py | 13 +- python/resdata/summary/rd_npv.py | 11 +- python/resdata/summary/rd_smspec_node.py | 2 +- python/resdata/summary/rd_sum.py | 38 +++-- python/resdata/summary/rd_sum_vector.py | 4 + python/resdata/util/test/extended_testcase.py | 9 +- python/resdata/util/test/mock/rd_sum_mock.py | 3 +- python/resdata/util/test/path_context.py | 5 +- .../resdata/util/test/resdata_test_runner.py | 
2 +- python/resdata/util/test/source_enumerator.py | 6 +- python/resdata/util/test/test_run.py | 2 +- python/resdata/util/util/__init__.py | 2 +- python/resdata/util/util/ctime.py | 3 +- python/resdata/util/util/lookup_table.py | 12 +- python/resdata/util/util/thread_pool.py | 14 +- python/resdata/util/util/time_vector.py | 6 +- python/resdata/util/util/vector_template.py | 32 ++-- python/resdata/util/util/version.py | 7 +- python/resdata/well/well_state.py | 5 +- python/tests/__init__.py | 11 +- python/tests/geometry_tests/test_surface.py | 6 +- python/tests/rd_tests/test_fault_blocks.py | 1 - python/tests/rd_tests/test_faults.py | 4 +- python/tests/rd_tests/test_fortio.py | 5 +- python/tests/rd_tests/test_geertsma.py | 28 ---- python/tests/rd_tests/test_grid.py | 2 +- python/tests/rd_tests/test_grid_equinor.py | 17 +- python/tests/rd_tests/test_grid_generator.py | 10 +- python/tests/rd_tests/test_npv.py | 2 +- python/tests/rd_tests/test_rd_cmp.py | 2 +- python/tests/rd_tests/test_rd_kw.py | 23 ++- python/tests/rd_tests/test_rd_sum.py | 6 +- python/tests/rd_tests/test_rd_type.py | 6 +- python/tests/rd_tests/test_rd_util.py | 2 +- python/tests/rd_tests/test_region_equinor.py | 2 +- python/tests/rd_tests/test_sum.py | 30 +--- python/tests/rd_tests/test_sum_equinor.py | 15 +- python/tests/test_bin.py | 2 +- python/tests/util_tests/test_path_context.py | 16 +- python/tests/util_tests/test_string_list.py | 2 +- python/tests/util_tests/test_thread_pool.py | 2 +- python/tests/util_tests/test_vectors.py | 9 +- python/tests/well_tests/test_rd_well.py | 2 +- setup.cfg | 5 - 76 files changed, 412 insertions(+), 538 deletions(-) delete mode 100644 setup.cfg diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 0dbb6e269..cf52b2cea 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -13,7 +13,7 @@ jobs: - name: Install dependencies run: | - sudo pip install cmake-format black + sudo pip install cmake-format - name: Clang Format 
run: ./script/clang-format --check @@ -23,6 +23,13 @@ jobs: find . -name 'CMakeLists.txt' -o -name '*.cmake' > cmake-src xargs cmake-format --check < cmake-src - - name: Black + - name: Setup python + uses: actions/setup-python@v5 + with: + python-version: ["3.12"] + + - name: Run ruff run: | - black --check . + pip install ruff + ruff check + ruff format --check diff --git a/cmake/create_cmakelists.py b/cmake/create_cmakelists.py index 173331e6b..a5c5fd370 100755 --- a/cmake/create_cmakelists.py +++ b/cmake/create_cmakelists.py @@ -11,7 +11,7 @@ def findFilesAndDirectories(directory): directories = [] for f in all_files: path = join(directory, f) - if isfile(path) and not f == "CMakeLists.txt" and not islink(path): + if isfile(path) and f != "CMakeLists.txt" and not islink(path): files.append(f) if isdir(path): directories.append(f) diff --git a/pyproject.toml b/pyproject.toml index 9abf16845..24c3f4e65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,43 @@ [build-system] requires = ["setuptools", "setuptools_scm", "wheel", "scikit-build", "cmake", "conan<2", "ninja"] + +[tool.ruff] +src = ["python"] +line-length = 88 + +[tool.ruff.lint] +select = [ + "W", # pycodestyle + # "I", # isort ( Issues with circular imports and cwrap) + "B", # flake-8-bugbear + "SIM", # flake-8-simplify + "F", # pyflakes + "PL", # pylint + "NPY", # numpy specific rules + "C4", # flake8-comprehensions + "PD", # pandas-vet +] +ignore = ["PLW2901", # redefined-loop-name + "PLR2004", # magic-value-comparison + "PLR0915", # too-many-statements + "PLR0912", # too-many-branches + "PLR0911", # too-many-return-statements + "PLC2701", # import-private-name + "PLR6201", # literal-membership + "PLR0914", # too-many-locals + "PLR6301", # no-self-use + "PLW1641", # eq-without-hash + "PLR0904", # too-many-public-methods + "PLR1702", # too-many-nested-blocks + "PLW3201", # bad-dunder-method-name + "PD901", + "C409", + "PLC0414", + "F401", + "F841", +] +[tool.ruff.lint.extend-per-file-ignores] 
+"python/tests/util_tests/test_ctime.py" = ["PLR0124", "B015"] + +[tool.ruff.lint.pylint] +max-args = 15 diff --git a/python/docs/examples/avg_pressure.py b/python/docs/examples/avg_pressure.py index 7a0131dff..db97178a7 100755 --- a/python/docs/examples/avg_pressure.py +++ b/python/docs/examples/avg_pressure.py @@ -2,6 +2,7 @@ import sys import matplotlib.pyplot as plt + from resdata.grid import Grid, ResdataRegion from resdata.resfile import ResdataFile, ResdataRestartFile @@ -19,15 +20,9 @@ def avg_pressure(p, sw, pv, region, region_id, result): p1 = p.sum(mask=region) / region.active_size() - if total_pv > 0: - p2 = p_pv.sum(mask=region) / total_pv - else: - p2 = None + p2 = p_pv.sum(mask=region) / total_pv if total_pv > 0 else None - if total_hc_pv > 0: - p3 = p_hc_pv.sum(mask=region) / total_hc_pv - else: - p3 = None + p3 = p_hc_pv.sum(mask=region) / total_hc_pv if total_hc_pv > 0 else None else: p1 = None p2 = None @@ -71,7 +66,7 @@ def avg_pressure(p, sw, pv, region, region_id, result): avg_pressure(p, sw, pv, ResdataRegion(grid, True), "field", result) sim_days.append(header.get_sim_days()) - for key in result.keys(): + for key in result: plt.figure(1) for index, p in enumerate(result[key]): plt.plot(sim_days, p, label="Region:%s P%d" % (key, index + 1)) diff --git a/python/docs/examples/grid_info.py b/python/docs/examples/grid_info.py index 7e442412e..e2aedcaf1 100755 --- a/python/docs/examples/grid_info.py +++ b/python/docs/examples/grid_info.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import sys -from resdata.grid import ResdataRegion, Grid + +from resdata.grid import Grid, ResdataRegion def volume_min_max(grid): @@ -26,7 +27,7 @@ def main(grid): if __name__ == "__main__": if len(sys.argv) < 2: - exit("usage: grid_info.py path/to/file.EGRID") + sys.exit("usage: grid_info.py path/to/file.EGRID") case = sys.argv[1] grid = Grid(case) main(grid) diff --git a/python/resdata/__init__.py b/python/resdata/__init__.py index 5222260f9..a369d228a 100644 --- 
a/python/resdata/__init__.py +++ b/python/resdata/__init__.py @@ -69,7 +69,7 @@ def _dlopen_resdata(): @ct.CFUNCTYPE(None, ct.c_char_p, ct.c_int, ct.c_char_p, ct.c_char_p, ct.c_char_p) def _c_abort_handler(filename, lineno, function, message, backtrace): - global _abort_handler + global _abort_handler # noqa: PLW0602 if not _abort_handler: return _abort_handler( @@ -85,7 +85,7 @@ def set_abort_handler(function): """ Set callback function for util_abort, which is called prior to std::abort() """ - global _abort_handler + global _abort_handler # noqa: PLW0603 _abort_handler = function ResdataPrototype.lib.util_set_abort_handler(_c_abort_handler) @@ -102,11 +102,11 @@ def __init__(self, prototype, bind=True): from .rd_type import ResDataType, ResdataTypeEnum from .rd_util import ( - FileType, FileMode, + FileType, Phase, - UnitSystem, ResdataUtil, + UnitSystem, ) from .util.util import ResdataVersion, updateAbortSignals diff --git a/python/resdata/geometry/cpolyline.py b/python/resdata/geometry/cpolyline.py index 559fe797a..fcd64949e 100644 --- a/python/resdata/geometry/cpolyline.py +++ b/python/resdata/geometry/cpolyline.py @@ -6,7 +6,9 @@ import os.path from cwrap import BaseCClass + from resdata import ResdataPrototype + from .geometry_tools import GeometryTools @@ -54,10 +56,7 @@ def createFromXYZFile(cls, filename, name=None): def __str__(self): name = self.getName() - if name: - str = "%s [" % name - else: - str = "[" + str = "%s [" % name if name else "[" for index, p in enumerate(self): str += "(%g,%g)" % p @@ -143,10 +142,7 @@ def extendToBBox(self, bbox, start=True): intersections = GeometryTools.rayPolygonIntersections(p1, ray_dir, bbox) if intersections: p2 = intersections[0][1] - if self.getName(): - name = "Extend:%s" % self.getName() - else: - name = None + name = "Extend:%s" % self.getName() if self.getName() else None return CPolyline(name=name, init_points=[(p1[0], p1[1]), p2]) else: diff --git a/python/resdata/geometry/cpolyline_collection.py 
b/python/resdata/geometry/cpolyline_collection.py index 80608ac31..52d2f44f0 100644 --- a/python/resdata/geometry/cpolyline_collection.py +++ b/python/resdata/geometry/cpolyline_collection.py @@ -85,11 +85,10 @@ def shallowCopy(self): def addPolyline(self, polyline, name=None): if not isinstance(polyline, CPolyline): polyline = CPolyline(init_points=polyline, name=name) - else: - if not name is None: - raise ValueError( - "The name keyword argument can only be supplied when add not CPOlyline object" - ) + elif not name is None: + raise ValueError( + "The name keyword argument can only be supplied when add not CPOlyline object" + ) name = polyline.getName() if name and name in self: diff --git a/python/resdata/geometry/geo_region.py b/python/resdata/geometry/geo_region.py index 33c7cb597..54bb00ba4 100644 --- a/python/resdata/geometry/geo_region.py +++ b/python/resdata/geometry/geo_region.py @@ -42,7 +42,7 @@ class GeoRegion(BaseCClass): ) def __init__(self, pointset, preselect=False): - self._preselect = True if preselect else False + self._preselect = bool(preselect) c_ptr = self._alloc(pointset, self._preselect) if c_ptr: super(GeoRegion, self).__init__(c_ptr) @@ -68,7 +68,7 @@ def _construct_cline(self, line): x2, y2 = map(float, p2) except Exception as err: err_msg = "Select with pair ((x1,y1), (x2,y2)), not %s (%s)." 
- raise ValueError(err_msg % (line, err)) + raise ValueError(err_msg % (line, err)) from err x1x2_ptr = cpair(x1, x2) y1y2_ptr = cpair(y1, y2) return x1x2_ptr, y1y2_ptr diff --git a/python/resdata/geometry/geometry_tools.py b/python/resdata/geometry/geometry_tools.py index 19d6594c4..1bf0845ee 100644 --- a/python/resdata/geometry/geometry_tools.py +++ b/python/resdata/geometry/geometry_tools.py @@ -1,6 +1,6 @@ -from math import sqrt import functools import sys +from math import sqrt class GeometryTools(object): @@ -135,13 +135,12 @@ def pointInPolygon(p, polygon): for index in range(n + 1): p2x, p2y = polygon[index % n][0:2] - if min(p1y, p2y) < y <= max(p1y, p2y): - if x <= max(p1x, p2x): - if p1y != p2y: - xints = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x + if min(p1y, p2y) < y <= max(p1y, p2y) and x <= max(p1x, p2x): + if p1y != p2y: + xints = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x - if p1x == p2x or x <= xints: - inside = not inside + if p1x == p2x or x <= xints: + inside = not inside p1x, p1y = p2x, p2y @@ -159,9 +158,7 @@ def extendToEdge(bounding_polygon, poly_line): ray1 = GeometryTools.lineToRay(poly_line[1], poly_line[0]) intersection1 = GeometryTools.rayPolygonIntersections( p1, ray1, bounding_polygon - )[ - 0 - ] # assume convex + )[0] # assume convex p2 = poly_line[-1] assert GeometryTools.pointInPolygon(p2, bounding_polygon) @@ -172,9 +169,7 @@ def extendToEdge(bounding_polygon, poly_line): ) intersection2 = GeometryTools.rayPolygonIntersections( p2, ray2, bounding_polygon - )[ - 0 - ] # assume convex + )[0] # assume convex return [intersection1[1]] + poly_line + [intersection2[1]] @@ -196,17 +191,13 @@ def slicePolygon(bounding_polygon, poly_line): tmp = GeometryTools.rayPolygonIntersections(p1, ray1, bounding_polygon) intersection1 = GeometryTools.rayPolygonIntersections( p1, ray1, bounding_polygon - )[ - 0 - ] # assume convex + )[0] # assume convex p2 = poly_line[-1] ray2 = GeometryTools.lineToRay(poly_line[-2], poly_line[-1]) intersection2 
= GeometryTools.rayPolygonIntersections( p2, ray2, bounding_polygon - )[ - 0 - ] # assume convex + )[0] # assume convex # Check for intersection between the polyline extensions on the inside of the bounadary internal_intersection = GeometryTools.lineIntersection( @@ -394,13 +385,17 @@ def connectPolylines(polyline, target_polyline): p0 = polyline[-1] p1 = polyline[-2] ray = GeometryTools.lineToRay(p1, p0) - for index, p in GeometryTools.rayPolygonIntersections(p0, ray, target_polyline): + for _index, p in GeometryTools.rayPolygonIntersections( + p0, ray, target_polyline + ): d_list.append((GeometryTools.distance(p0, p), [p0, p])) p0 = polyline[0] p1 = polyline[1] ray = GeometryTools.lineToRay(p1, p0) - for index, p in GeometryTools.rayPolygonIntersections(p0, ray, target_polyline): + for _index, p in GeometryTools.rayPolygonIntersections( + p0, ray, target_polyline + ): d_list.append((GeometryTools.distance(p0, p), [p0, p])) if len(d_list) == 0: diff --git a/python/resdata/geometry/polyline.py b/python/resdata/geometry/polyline.py index 2fff2e0e0..1c7566802 100644 --- a/python/resdata/geometry/polyline.py +++ b/python/resdata/geometry/polyline.py @@ -1,4 +1,5 @@ import collections + from .geometry_tools import GeometryTools @@ -40,20 +41,13 @@ def __eq__(self, other): if len(self) != len(other): return False - for p1, p2 in zip(self, other): - if p1 != p2: - return False - - return True + return all(p1 == p2 for p1, p2 in zip(self, other)) def __len__(self): return len(self.__points) def addPoint(self, x, y, z=None): - if z is None: - p = (x, y) - else: - p = (x, y, z) + p = (x, y) if z is None else (x, y, z) self.__points.append(p) def __getitem__(self, index): diff --git a/python/resdata/geometry/surface.py b/python/resdata/geometry/surface.py index 92150fb50..a361fa2a2 100644 --- a/python/resdata/geometry/surface.py +++ b/python/resdata/geometry/surface.py @@ -255,11 +255,10 @@ def _assert_idx_or_i_and_j(self, idx, i, j): raise ValueError( "idx is None, i and j 
must be ints, was %s and %s." % (i, j) ) - else: - if i is not None or j is not None: - raise ValueError( - "idx is set, i and j must be None, was %s and %s." % (i, j) - ) + elif i is not None or j is not None: + raise ValueError( + "idx is set, i and j must be None, was %s and %s." % (i, j) + ) def getXYZ(self, idx=None, i=None, j=None): """Returns a tuple of 3 floats, (x,y,z) for given global index, or i and j.""" diff --git a/python/resdata/gravimetry/rd_subsidence.py b/python/resdata/gravimetry/rd_subsidence.py index 0ad6b7bb8..4597dbab5 100644 --- a/python/resdata/gravimetry/rd_subsidence.py +++ b/python/resdata/gravimetry/rd_subsidence.py @@ -98,9 +98,8 @@ def eval_geertsma( if not base_survey in self: raise KeyError("No such survey: %s" % base_survey) - if monitor_survey is not None: - if not monitor_survey in self: - raise KeyError("No such survey: %s" % monitor_survey) + if monitor_survey is not None and not monitor_survey in self: + raise KeyError("No such survey: %s" % monitor_survey) return self._eval_geertsma( base_survey, @@ -127,9 +126,8 @@ def eval_geertsma_rporv( if not base_survey in self: raise KeyError("No such survey: %s" % base_survey) - if monitor_survey is not None: - if not monitor_survey in self: - raise KeyError("No such survey: %s" % monitor_survey) + if monitor_survey is not None and not monitor_survey in self: + raise KeyError("No such survey: %s" % monitor_survey) return self._eval_geertsma_rporv( base_survey, diff --git a/python/resdata/grid/cell.py b/python/resdata/grid/cell.py index 9b17c10f7..ff0de4df1 100644 --- a/python/resdata/grid/cell.py +++ b/python/resdata/grid/cell.py @@ -81,7 +81,7 @@ def __eq__(self, other): def __neq__(self, other): if isinstance(other, Cell): - return not self == other + return self != other return NotImplemented def hash(self): diff --git a/python/resdata/grid/faults/fault.py b/python/resdata/grid/faults/fault.py index b25a42758..892261947 100644 --- a/python/resdata/grid/faults/fault.py +++ 
b/python/resdata/grid/faults/fault.py @@ -100,7 +100,7 @@ def __sort_fault_lines(self): perm_list.sort(key=lambda x: x[1]) fault_lines = [] - for index, d in perm_list: + for index, _d in perm_list: fault_lines.append(self.__fault_lines[index]) self.__fault_lines = fault_lines @@ -208,32 +208,29 @@ def add_record(self, I1, I2, J1, J2, K1, K2, face): if K1 > K2: raise ValueError("Invalid K1 K2 indices") - if I1 < 0 or I1 >= self.nx: + if I1 < 0 or self.nx <= I1: raise ValueError("Invalid I1:%d" % I1) - if I2 < 0 or I2 >= self.nx: + if I2 < 0 or self.nx <= I2: raise ValueError("Invalid I2:%d" % I2) - if J1 < 0 or J1 >= self.ny: + if J1 < 0 or self.ny <= J1: raise ValueError("Invalid J1:%d" % J1) - if J2 < 0 or J2 >= self.ny: + if J2 < 0 or self.ny <= J2: raise ValueError("Invalid J2:%d" % J2) - if K1 < 0 or K1 >= self.nz: + if K1 < 0 or self.nz <= K1: raise ValueError("Invalid K1:%d" % K1) - if K2 < 0 or K2 >= self.nz: + if K2 < 0 or self.nz <= K2: raise ValueError("Invalid K2:%d" % K2) - if face in ["X", "I"]: - if I1 != I2: - raise ValueError("For face:%s we must have I1 == I2" % face) + if face in ["X", "I"] and I1 != I2: + raise ValueError("For face:%s we must have I1 == I2" % face) - if face in ["Y", "J"]: - if J1 != J2: - raise ValueError("For face:%s we must have J1 == J2" % face) + if face in ["Y", "J"] and J1 != J2: + raise ValueError("For face:%s we must have J1 == J2" % face) - if face in ["Z", "K"]: - if K1 != K2: - raise ValueError("For face:%s we must have K1 == K2" % face) + if face in ["Z", "K"] and K1 != K2: + raise ValueError("For face:%s we must have K1 == K2" % face) # ----------------------------------------------------------------- @@ -297,10 +294,7 @@ def connect_with_polyline(self, polyline, k): return None def connect(self, target, k): - if isinstance(target, Fault): - polyline = target.getPolyline(k) - else: - polyline = target + polyline = target.getPolyline(k) if isinstance(target, Fault) else target return 
self.connectWithPolyline(polyline, k) def extend_to_polyline(self, polyline, k): @@ -402,10 +396,7 @@ def extend_to_b_box(self, bbox, k, start=True): intersections = GeometryTools.rayPolygonIntersections(p1, ray_dir, bbox) if intersections: p2 = intersections[0][1] - if self.getName(): - name = "Extend:%s" % self.getName() - else: - name = None + name = "Extend:%s" % self.getName() if self.getName() else None return CPolyline(name=name, init_points=[(p1[0], p1[1]), p2]) else: @@ -414,10 +405,7 @@ def extend_to_b_box(self, bbox, k, start=True): def end_join(self, other, k): fault_polyline = self.getPolyline(k) - if isinstance(other, Fault): - other_polyline = other.getPolyline(k) - else: - other_polyline = other + other_polyline = other.getPolyline(k) if isinstance(other, Fault) else other return GeometryTools.joinPolylines(fault_polyline, other_polyline) @@ -462,13 +450,11 @@ def intersect_fault_rays(ray1, ray2): dx = p2[0] - p1[0] dy = p2[1] - p1[1] - if dx != 0: - if dir1[0] * dx <= 0 and dir2[0] * dx >= 0: - raise ValueError("Rays will never intersect") + if dx != 0 and dir1[0] * dx <= 0 and dir2[0] * dx >= 0: + raise ValueError("Rays will never intersect") - if dy != 0: - if dir1[1] * dy <= 0 and dir2[1] * dy >= 0: - raise ValueError("Rays will never intersect") + if dy != 0 and dir1[1] * dy <= 0 and dir2[1] * dy >= 0: + raise ValueError("Rays will never intersect") if dx * dy != 0: if dir1[0] != 0: @@ -502,10 +488,7 @@ def int_ray(p1, p2): raise Exception("Invalid direction") dy = 0 - if p2[0] > p1[0]: - dx = 1 - else: - dx = -1 + dx = 1 if p2[0] > p1[0] else -1 return [p2, (dx, dy)] diff --git a/python/resdata/grid/faults/fault_block.py b/python/resdata/grid/faults/fault_block.py index a74a4c17d..9ad0f6090 100644 --- a/python/resdata/grid/faults/fault_block.py +++ b/python/resdata/grid/faults/fault_block.py @@ -143,9 +143,8 @@ def contains_polyline(self, polyline): for p in polyline: if GeometryTools.pointInPolygon(p, edge_polyline): return True - else: - 
edge_polyline.assertClosed() - return GeometryTools.polylinesIntersect(edge_polyline, polyline) + edge_polyline.assertClosed() + return GeometryTools.polylinesIntersect(edge_polyline, polyline) def get_neighbours(self, polylines=None, connected_only=True): """ diff --git a/python/resdata/grid/faults/fault_block_layer.py b/python/resdata/grid/faults/fault_block_layer.py index 3176456a9..a2a4a175d 100644 --- a/python/resdata/grid/faults/fault_block_layer.py +++ b/python/resdata/grid/faults/fault_block_layer.py @@ -202,11 +202,11 @@ def join_faults(self, fault1, fault2): layer = self.getGeoLayer() try: layer.addIJBarrier(Fault.joinFaults(fault1, fault2, self.getK())) - except ValueError: + except ValueError as err: err = "Failed to join faults %s and %s" names = (fault1.getName(), fault2.getName()) print(err % names) - raise ValueError(err % names) + raise ValueError(err % names) from err def add_polyline_barrier(self, polyline): layer = self.getGeoLayer() diff --git a/python/resdata/grid/faults/fault_line.py b/python/resdata/grid/faults/fault_line.py index 45697e800..395029128 100644 --- a/python/resdata/grid/faults/fault_line.py +++ b/python/resdata/grid/faults/fault_line.py @@ -29,7 +29,7 @@ def verify(self): if len(self.__segment_list) > 1: current = self.__segment_list[0] for next_segment in self.__segment_list[1:]: - if not current.getC2() == next_segment.getC1(): + if current.getC2() != next_segment.getC1(): sys.stdout.write( "Current: %d ---- %d \n" % (current.getC1(), current.getC2()) ) @@ -54,7 +54,7 @@ def try_append(self, segment): else: segment.swap() - if not tail.getC2() == segment.getC1(): + if tail.getC2() != segment.getC1(): return False self.__segment_list.append(segment) @@ -124,10 +124,7 @@ def __init_neighbor_cells(self): i = i1 for j in range(j1, j2): g2 = i + j * nx + k * nx * ny - if i == 0: - g1 = -1 - else: - g1 = g2 - 1 + g1 = -1 if i == 0 else g2 - 1 if i == nx: g2 = -1 @@ -137,10 +134,7 @@ def __init_neighbor_cells(self): j = j1 for i in 
range(i1, i2): g2 = i + j * nx + k * nx * ny - if j == 0: - g1 = -1 - else: - g1 = g2 - nx + g1 = -1 if j == 0 else g2 - nx if j == ny: g2 = -1 diff --git a/python/resdata/grid/faults/fault_segments.py b/python/resdata/grid/faults/fault_segments.py index b5ff6e91a..abc620539 100644 --- a/python/resdata/grid/faults/fault_segments.py +++ b/python/resdata/grid/faults/fault_segments.py @@ -13,7 +13,7 @@ def __eq__(self, other): s = self.c1, self.c2 o = other.c1, other.c2 o_flipped = other.c2, other.c1 - return s == o or s == o_flipped + return s in (o, o_flipped) def __hash__(self): return hash(hash(self.__C1) + hash(self.__C2) + hash(self.__next_segment)) @@ -28,10 +28,7 @@ def joins(self, other): return True if self.__C2 == other.__C1: return True - if self.__C2 == other.__C2: - return True - - return False + return self.__C2 == other.__C2 def get_c1(self): return self.__C1 @@ -135,7 +132,7 @@ def pop_next(self, segment): def print_content(self): for d in self.__segment_map.values(): - for C, S in d.iteritems(): + for _C, S in d.iteritems(): print(S) diff --git a/python/resdata/grid/faults/layer.py b/python/resdata/grid/faults/layer.py index eaf263f42..7f5f6f61f 100644 --- a/python/resdata/grid/faults/layer.py +++ b/python/resdata/grid/faults/layer.py @@ -63,10 +63,10 @@ def _assert_ij(self, i, j): def __unpack_index(self, index): try: (i, j) = index - except TypeError: + except TypeError as err: raise ValueError( "Index:%s is invalid - must have two integers" % str(index) - ) + ) from err self._assert_ij(i, j) diff --git a/python/resdata/grid/rd_grid.py b/python/resdata/grid/rd_grid.py index d20c236a6..fb1bf8a87 100644 --- a/python/resdata/grid/rd_grid.py +++ b/python/resdata/grid/rd_grid.py @@ -280,6 +280,7 @@ def create_rectangular(cls, dims, dV, actnum=None): "Grid.createRectangular is deprecated. 
" + "Please use the similar method: GridGenerator.createRectangular.", DeprecationWarning, + stacklevel=1, ) if actnum is None: @@ -554,12 +555,11 @@ def __global_index(self, active_index=None, global_index=None, ijk=None): raise IndexError("Invalid value k:%d Range: [%d,%d)" % (k, 0, nz)) global_index = self._get_global_index3(i, j, k) - else: - if not 0 <= global_index < self.getGlobalSize(): - raise IndexError( - "Invalid value global_index:%d Range: [%d,%d)" - % (global_index, 0, self.getGlobalSize()) - ) + elif not 0 <= global_index < self.getGlobalSize(): + raise IndexError( + "Invalid value global_index:%d Range: [%d,%d)" + % (global_index, 0, self.getGlobalSize()) + ) return global_index def get_active_index(self, ijk=None, global_index=None): @@ -642,10 +642,7 @@ def active(self, ijk=None, global_index=None): """ gi = self.__global_index(global_index=global_index, ijk=ijk) active_index = self._get_active_index1(gi) - if active_index >= 0: - return True - else: - return False + return active_index >= 0 def get_global_index(self, ijk=None, active_index=None): """ @@ -1045,10 +1042,7 @@ def has_lgr(self, lgr_name): """ Query if the grid has an LGR with name @lgr_name. """ - if self._has_named_lgr(lgr_name): - return True - else: - return False + return self._has_named_lgr(lgr_name) def get_lgr(self, lgr_key): """Get Grid instance with LGR content. 
@@ -1067,9 +1061,8 @@ def get_lgr(self, lgr_key): if isinstance(lgr_key, int): if self._has_numbered_lgr(lgr_key): lgr = self._get_numbered_lgr(lgr_key) - else: - if self._has_named_lgr(lgr_key): - lgr = self._get_named_lgr(lgr_key) + elif self._has_named_lgr(lgr_key): + lgr = self._get_named_lgr(lgr_key) if lgr is None: raise KeyError("No such LGR: %s" % lgr_key) @@ -1157,10 +1150,7 @@ def create_kw(self, array, kw_name, pack): else: sys.exit("Do not know how to create rd_kw from type:%s" % dtype) - if pack: - size = self.getNumActive() - else: - size = self.getGlobalSize() + size = self.getNumActive() if pack else self.getGlobalSize() if len(kw_name) > 8: # Silently truncate to length 8 @@ -1176,11 +1166,10 @@ def create_kw(self, array, kw_name, pack): if self.active(global_index=global_index): kw[active_index] = array[i, j, k] active_index += 1 + elif dtype == numpy.int32: + kw[global_index] = int(array[i, j, k]) else: - if dtype == numpy.int32: - kw[global_index] = int(array[i, j, k]) - else: - kw[global_index] = array[i, j, k] + kw[global_index] = array[i, j, k] global_index += 1 return kw @@ -1379,10 +1368,7 @@ def export_index(self, active_only=False): This index frame should typically be passed to the epxport_data(), export_volume() and export_corners() functions. 
""" - if active_only: - size = self.get_num_active() - else: - size = self.get_global_size() + size = self.get_num_active() if active_only else self.get_global_size() indx = numpy.zeros(size, dtype=numpy.int32) data = numpy.zeros([size, 4], dtype=numpy.int32) self._export_index_frame( diff --git a/python/resdata/grid/rd_grid_generator.py b/python/resdata/grid/rd_grid_generator.py index 44f9a958f..c02382a6a 100644 --- a/python/resdata/grid/rd_grid_generator.py +++ b/python/resdata/grid/rd_grid_generator.py @@ -492,7 +492,7 @@ def assert_actnum(cls, nx, ny, nz, actnum): % (nx * ny * nz, len(actnum)) ) - if set(actnum) - set([0, 1]): + if set(actnum) - {0, 1}: raise AssertionError( "Expected ACTNUM to consist of 0's and 1's, was %s." % ", ".join(map(str, set(actnum))) diff --git a/python/resdata/grid/rd_region.py b/python/resdata/grid/rd_region.py index 69ad6e24b..28f2d1ada 100644 --- a/python/resdata/grid/rd_region.py +++ b/python/resdata/grid/rd_region.py @@ -1067,10 +1067,7 @@ def idiv_kw(self, target_kw, other, force_active=False): else: raise TypeError("Type mismatch") else: - if target_kw.data_type.is_int(): - scale = 1 // other - else: - scale = 1.0 / other + scale = 1 // other if target_kw.data_type.is_int() else 1.0 / other self.scale_kw(target_kw, scale, force_active) def copy_kw(self, target_kw, src_kw, force_active=False): diff --git a/python/resdata/rd_util.py b/python/resdata/rd_util.py index 50932e08f..1454330fd 100644 --- a/python/resdata/rd_util.py +++ b/python/resdata/rd_util.py @@ -15,6 +15,7 @@ import ctypes from cwrap import BaseCEnum + from resdata import ResdataPrototype from resdata.util.util import monkey_the_camel @@ -140,10 +141,7 @@ def inspect_extension(filename): file_type = ResdataUtil._get_file_type( filename, ctypes.byref(fmt_file), ctypes.byref(report_step) ) - if report_step.value == -1: - step = None - else: - step = report_step.value + step = None if report_step.value == -1 else report_step.value return (file_type, fmt_file.value, 
step) diff --git a/python/resdata/resfile/rd_3d_file.py b/python/resdata/resfile/rd_3d_file.py index 0feed60ff..c5bbc2c64 100644 --- a/python/resdata/resfile/rd_3d_file.py +++ b/python/resdata/resfile/rd_3d_file.py @@ -1,4 +1,4 @@ -from resdata.resfile import ResdataFile, Resdata3DKW +from resdata.resfile import Resdata3DKW, ResdataFile class Resdata3DFile(ResdataFile): @@ -8,14 +8,11 @@ def __init__(self, grid, filename, flags=0): def __getitem__(self, index): return_arg = super(Resdata3DFile, self).__getitem__(index) - if isinstance(return_arg, list): - kw_list = return_arg - else: - kw_list = [return_arg] + kw_list = return_arg if isinstance(return_arg, list) else [return_arg] # Go through all the keywords and try inplace promotion to Resdata3DKW for kw in kw_list: - try: + try: # noqa: SIM105 Resdata3DKW.castFromKW(kw, self.grid) except ValueError: pass diff --git a/python/resdata/resfile/rd_3dkw.py b/python/resdata/resfile/rd_3dkw.py index 791ae8d22..5c0e6e825 100644 --- a/python/resdata/resfile/rd_3dkw.py +++ b/python/resdata/resfile/rd_3dkw.py @@ -1,4 +1,5 @@ from resdata.util.util import monkey_the_camel + from .rd_kw import ResdataKW @@ -51,10 +52,7 @@ class Resdata3DKW(ResdataKW): """ def __init__(self, kw, grid, value_type, default_value=0, global_active=False): - if global_active: - size = grid.getGlobalSize() - else: - size = grid.getNumActive() + size = grid.getGlobalSize() if global_active else grid.getNumActive() super(Resdata3DKW, self).__init__(kw, size, value_type) self.grid = grid self.global_active = global_active @@ -93,11 +91,10 @@ def __getitem__(self, index): global_index = self.grid.get_global_index(ijk=index) if self.global_active: index = global_index + elif not self.grid.active(global_index=global_index): + return self.getDefault() else: - if not self.grid.active(global_index=global_index): - return self.getDefault() - else: - index = self.grid.get_active_index(ijk=index) + index = self.grid.get_active_index(ijk=index) return 
super(Resdata3DKW, self).__getitem__(index) @@ -116,13 +113,12 @@ def __setitem__(self, index, value): global_index = self.grid.get_global_index(ijk=index) if self.global_active: index = global_index + elif not self.grid.active(global_index=global_index): + raise ValueError( + "Tried to assign value to inactive cell: (%d,%d,%d)" % index + ) else: - if not self.grid.active(global_index=global_index): - raise ValueError( - "Tried to assign value to inactive cell: (%d,%d,%d)" % index - ) - else: - index = self.grid.get_active_index(ijk=index) + index = self.grid.get_active_index(ijk=index) return super(Resdata3DKW, self).__setitem__(index, value) diff --git a/python/resdata/resfile/rd_file.py b/python/resdata/resfile/rd_file.py index 4b4303159..c2e6ac5d0 100644 --- a/python/resdata/resfile/rd_file.py +++ b/python/resdata/resfile/rd_file.py @@ -26,6 +26,7 @@ import re from cwrap import BaseCClass + from resdata import FileMode, FileType, ResdataPrototype from resdata.resfile import ResdataKW from resdata.util.util import CTime, monkey_the_camel @@ -121,19 +122,19 @@ def report_list(self): seqnum_list = self["SEQNUM"] for s in seqnum_list: report_steps.append(s[0]) - except KeyError: + except KeyError as err: # OK - we did not have seqnum; that might be because this # a non-unified restart file; or because this is not a # restart file at all. fname = self.getFilename() - matchObj = re.search("\.[XF](\d{4})$", fname) + matchObj = re.search(r"\.[XF](\d{4})$", fname) if matchObj: report_steps.append(int(matchObj.group(1))) else: raise TypeError( 'Tried get list of report steps from file "%s" - which is not a restart file' % fname - ) + ) from err return report_steps @@ -364,13 +365,12 @@ def restart_get_kw(self, kw_name, dtime, copy=False): return ResdataKW.copy(kw) else: return kw + elif self.has_kw(kw_name): + raise IndexError( + 'Does not have keyword "%s" at time:%s.' 
% (kw_name, dtime) + ) else: - if self.has_kw(kw_name): - raise IndexError( - 'Does not have keyword "%s" at time:%s.' % (kw_name, dtime) - ) - else: - raise KeyError('Keyword "%s" not recognized.' % kw_name) + raise KeyError('Keyword "%s" not recognized.' % kw_name) else: raise IndexError( 'Does not have keyword "%s" at time:%s.' % (kw_name, dtime) diff --git a/python/resdata/resfile/rd_file_view.py b/python/resdata/resfile/rd_file_view.py index 44bb363cf..e4abc9977 100644 --- a/python/resdata/resfile/rd_file_view.py +++ b/python/resdata/resfile/rd_file_view.py @@ -115,10 +115,7 @@ def __len__(self): return self._get_size() def __contains__(self, kw): - if self.numKeywords(kw) > 0: - return True - else: - return False + return self.numKeywords(kw) > 0 def num_keywords(self, kw): return self._get_num_named_kw(kw) @@ -143,9 +140,8 @@ def block_view2(self, start_kw, stop_kw, start_index): "Index must be in [0, %d), was: %d." % (ls, start_index) ) - if stop_kw: - if not stop_kw in self: - raise KeyError("The keyword:%s is not in file" % stop_kw) + if stop_kw and not stop_kw in self: + raise KeyError("The keyword:%s is not in file" % stop_kw) view = self._create_block_view2(start_kw, stop_kw, idx) view.setParent(parent=self) diff --git a/python/resdata/resfile/rd_kw.py b/python/resdata/resfile/rd_kw.py index 256c90e30..8613c832e 100644 --- a/python/resdata/resfile/rd_kw.py +++ b/python/resdata/resfile/rd_kw.py @@ -28,6 +28,7 @@ import numpy from cwrap import CFILE, BaseCClass + from resdata import ResdataPrototype, ResDataType, ResdataTypeEnum, ResdataUtil from resdata.util.util import monkey_the_camel @@ -38,6 +39,7 @@ def dump_type_deprecation_warning(): warnings.warn( "ResdataTypeEnum is deprecated. You should instead provide an ResDataType", DeprecationWarning, + stacklevel=1, ) @@ -82,18 +84,16 @@ class ResdataKW(BaseCClass): limit the operation to a part of the ResdataKW. 
""" - int_kw_set = set( - [ - "PVTNUM", - "FIPNUM", - "EQLNUM", - "FLUXNUM", - "MULTNUM", - "ACTNUM", - "SPECGRID", - "REGIONS", - ] - ) + int_kw_set = { + "PVTNUM", + "FIPNUM", + "EQLNUM", + "FLUXNUM", + "MULTNUM", + "ACTNUM", + "SPECGRID", + "REGIONS", + } TYPE_NAME = "rd_kw" _alloc_new = ResdataPrototype( @@ -312,12 +312,10 @@ def read_grdecl(cls, fileH, kw, strict=True, rd_type=None): """ cfile = CFILE(fileH) - if kw: - if len(kw) > 8: - raise TypeError( - "Sorry keyword:%s is too long, must be eight characters or less." - % kw - ) + if kw and len(kw) > 8: + raise TypeError( + "Sorry keyword:%s is too long, must be eight characters or less." % kw + ) if rd_type is None: if cls.int_kw_set.__contains__(kw): @@ -502,18 +500,16 @@ def __getitem__(self, index): if index < 0 or index >= length: raise IndexError + elif self.data_ptr: + return self.data_ptr[index] + elif self.data_type.is_bool(): + return self._iget_bool(index) + elif self.data_type.is_char(): + return self._iget_char_ptr(index) + elif self.data_type.is_string(): + return self._iget_string_ptr(index) else: - if self.data_ptr: - return self.data_ptr[index] - else: - if self.data_type.is_bool(): - return self._iget_bool(index) - elif self.data_type.is_char(): - return self._iget_char_ptr(index) - elif self.data_type.is_string(): - return self._iget_string_ptr(index) - else: - raise TypeError("Internal implementation error ...") + raise TypeError("Internal implementation error ...") elif isinstance(index, slice): return self.slice_copy(index) else: @@ -531,18 +527,16 @@ def __setitem__(self, index, value): if index < 0 or index >= length: raise IndexError + elif self.data_ptr: + self.data_ptr[index] = value + elif self.data_type.is_bool(): + self._iset_bool(index, value) + elif self.data_type.is_char(): + return self._iset_char_ptr(index, value) + elif self.data_type.is_string(): + return self._iset_string_ptr(index, value) else: - if self.data_ptr: - self.data_ptr[index] = value - else: - if 
self.data_type.is_bool(): - self._iset_bool(index, value) - elif self.data_type.is_char(): - return self._iset_char_ptr(index, value) - elif self.data_type.is_string(): - return self._iset_string_ptr(index, value) - else: - raise SystemError("Internal implementation error ...") + raise SystemError("Internal implementation error ...") elif isinstance(index, slice): (start, stop, step) = index.indices(len(self)) index = start @@ -573,11 +567,10 @@ def __IMUL__(self, factor, mul=True): self._scale_int(factor) else: raise TypeError("Type mismatch") + elif isinstance(factor, (int, float)): + self._scale_float(factor) else: - if isinstance(factor, int) or isinstance(factor, float): - self._scale_float(factor) - else: - raise TypeError("Only muliplication with scalar supported") + raise TypeError("Only muliplication with scalar supported") else: raise TypeError("Not numeric type") @@ -594,23 +587,19 @@ def __IADD__(self, delta, add=True): else: raise TypeError("Type / size mismatch") else: - if add: - sign = 1 - else: - sign = -1 + sign = 1 if add else -1 if self.data_type.is_int(): if isinstance(delta, int): self._shift_int(delta * sign) else: raise TypeError("Type mismatch") + elif isinstance(delta, (int, float)): + self._shift_float( + delta * sign + ) # Will call the _float() or _double() function in the C layer. else: - if isinstance(delta, int) or isinstance(delta, float): - self._shift_float( - delta * sign - ) # Will call the _float() or _double() function in the C layer. 
- else: - raise TypeError("Type mismatch") + raise TypeError("Type mismatch") else: raise TypeError("Type / size mismatch") @@ -696,9 +685,7 @@ def sum(self, mask=None, force_active=False): if mask is None: if self.data_type.is_int(): return self._int_sum() - elif self.data_type.is_float(): - return self._float_sum() - elif self.data_type.is_double(): + elif self.data_type.is_float() or self.data_type.is_double(): return self._float_sum() elif self.data_type.is_bool(): sum = 0 @@ -765,25 +752,21 @@ def assign(self, value, mask=None, force_active=False): if type(value) == type(self): if mask is not None: mask.copy_kw(self, value, force_active) + elif self.assert_binary(value): + self._copy_data(value) else: - if self.assert_binary(value): - self._copy_data(value) - else: - raise TypeError("Type / size mismatch") - else: - if mask is not None: - mask.set_kw(self, value, force_active) + raise TypeError("Type / size mismatch") + elif mask is not None: + mask.set_kw(self, value, force_active) + elif self.data_type.is_int(): + if isinstance(value, int): + self._set_int(value) else: - if self.data_type.is_int(): - if isinstance(value, int): - self._set_int(value) - else: - raise TypeError("Type mismatch") - else: - if isinstance(value, int) or isinstance(value, float): - self._set_float(value) - else: - raise TypeError("Only muliplication with scalar supported") + raise TypeError("Type mismatch") + elif isinstance(value, (int, float)): + self._set_float(value) + else: + raise TypeError("Only muliplication with scalar supported") def add(self, other, mask=None, force_active=False): """ @@ -856,13 +839,12 @@ def cutoff(x, limit): else: for index in active_list: self.data_ptr[index] = func(self.data_ptr[index]) + elif arg: + for i in range(len(self)): + self.data_ptr[i] = func(self.data_ptr[i], arg) else: - if arg: - for i in range(len(self)): - self.data_ptr[i] = func(self.data_ptr[i], arg) - else: - for i in range(len(self)): - self.data_ptr[i] = func(self.data_ptr[i]) + 
for i in range(len(self)): + self.data_ptr[i] = func(self.data_ptr[i]) def equal(self, other): """ @@ -982,7 +964,11 @@ def get_min(self): @property def type(self): - warnings.warn("rd_kw.type is deprecated, use .data_type", DeprecationWarning) + warnings.warn( + "rd_kw.type is deprecated, use .data_type", + DeprecationWarning, + stacklevel=1, + ) return self._get_type() @property @@ -993,14 +979,12 @@ def data_type(self): def type_name(self): return self.data_type.type_name - def type_name(self): - return self.data_type.type_name - def get_rd_type(self): warnings.warn( "ResdataTypeEnum is deprecated. " + "You should instead provide an ResDataType", DeprecationWarning, + stacklevel=1, ) return self._get_type() @@ -1012,7 +996,7 @@ def header(self): @property def array(self): a = self.data_ptr - if not a == None: + if a != None: a.size = len(self) a.__parent__ = self # Inhibit GC return a diff --git a/python/resdata/resfile/rd_restart_file.py b/python/resdata/resfile/rd_restart_file.py index 302cdc80d..bc4c86945 100644 --- a/python/resdata/resfile/rd_restart_file.py +++ b/python/resdata/resfile/rd_restart_file.py @@ -1,4 +1,5 @@ from cwrap import BaseCClass + from resdata import FileMode, FileType, ResdataPrototype from resdata.resfile import Resdata3DFile, ResdataFile from resdata.util.util import CTime, monkey_the_camel @@ -104,10 +105,7 @@ def assert_headers(self): else: intehead_kw = self["INTEHEAD"][0] doubhead_kw = self["DOUBHEAD"][0] - if "LOGIHEAD" in self: - logihead_kw = self["LOGIHEAD"][0] - else: - logihead_kw = None + logihead_kw = self["LOGIHEAD"][0] if "LOGIHEAD" in self else None self.rst_headers.append( ResdataRestartHead( diff --git a/python/resdata/rft/rd_rft.py b/python/resdata/rft/rd_rft.py index efe6a7568..778ad5539 100644 --- a/python/resdata/rft/rd_rft.py +++ b/python/resdata/rft/rd_rft.py @@ -3,6 +3,7 @@ """ from cwrap import BaseCClass + from resdata import ResdataPrototype from resdata.rft import ResdataPLTCell, ResdataRFTCell from 
resdata.util.util import CTime, monkey_the_camel @@ -238,10 +239,7 @@ def size(self, well=None, date=None): >>> print "RFTs at 01/01/2010 : %d" % rftFile.size( date = datetime.date( 2010 , 1 , 1 )) """ - if date: - cdate = CTime(date) - else: - cdate = CTime(-1) + cdate = CTime(date) if date else CTime(-1) return self._get_size(well, cdate) diff --git a/python/resdata/rft/well_trajectory.py b/python/resdata/rft/well_trajectory.py index 7af9b9920..20d7722d7 100644 --- a/python/resdata/rft/well_trajectory.py +++ b/python/resdata/rft/well_trajectory.py @@ -1,6 +1,6 @@ import sys -from os.path import isfile from collections import namedtuple +from os.path import isfile TrajectoryPoint = namedtuple( "TrajectoryPoint", @@ -29,8 +29,10 @@ def _parse_point(point): utm_y = float(point[1]) md = float(point[2]) tvd = float(point[3]) - except ValueError: - raise UserWarning("Error: Failed to extract data from line %s\n" % str(point)) + except ValueError as err: + raise UserWarning( + "Error: Failed to extract data from line %s\n" % str(point) + ) from err zone = None if len(point) > 4: zone = point[4] diff --git a/python/resdata/summary/rd_cmp.py b/python/resdata/summary/rd_cmp.py index 9eba124c5..20f41c13c 100644 --- a/python/resdata/summary/rd_cmp.py +++ b/python/resdata/summary/rd_cmp.py @@ -1,4 +1,5 @@ from resdata.util.util import monkey_the_camel + from .rd_sum import Summary @@ -26,16 +27,10 @@ def load_summary(self): self.summary = Summary(self.case) def start_time_equal(self, other): - if self.summary.getDataStartTime() == other.summary.getDataStartTime(): - return True - else: - return False + return self.summary.getDataStartTime() == other.summary.getDataStartTime() def end_time_equal(self, other): - if self.summary.getEndTime() == other.summary.getEndTime(): - return True - else: - return False + return self.summary.getEndTime() == other.summary.getEndTime() def cmp_summary_vector(self, other, key, sample=100): if key in self and key in other: @@ -95,7 +90,7 @@ def 
end_time_equal(self): return self.test_case.endTimeEqual(self.ref_case) def cmp_summary_vector(self, key, sample=100): - """Will compare the summary vectors according to @key. + r"""Will compare the summary vectors according to @key. The comparison is based on evaluating the integrals: diff --git a/python/resdata/summary/rd_npv.py b/python/resdata/summary/rd_npv.py index f6f0dd8d8..638be8c1c 100644 --- a/python/resdata/summary/rd_npv.py +++ b/python/resdata/summary/rd_npv.py @@ -1,8 +1,9 @@ -import re import datetime import numbers +import re from resdata.util.util import monkey_the_camel + from .rd_sum import Summary @@ -51,8 +52,8 @@ def add_item(self, item): year = int(tmp[2]) date = datetime.date(year, month, day) - except Exception: - raise ValueError("First element was invalid date item") + except Exception as err: + raise ValueError("First element was invalid date item") from err if len(self.dateList): prevItem = self.dateList[-1] @@ -112,14 +113,14 @@ def eval(self, date): class ResdataNPV(object): - sumKeyRE = re.compile("[\[]([\w:,]+)[\]]") + sumKeyRE = re.compile(r"[\[]([\w:,]+)[\]]") def __init__(self, baseCase): sum = Summary(baseCase) if sum: self.baseCase = sum else: - raise Error("Failed to open summary case:%s" % baseCase) + raise OSError("Failed to open summary case:%s" % baseCase) self.expression = None self.keyList = {} self.start = None diff --git a/python/resdata/summary/rd_smspec_node.py b/python/resdata/summary/rd_smspec_node.py index e793b8c03..5460aae7b 100644 --- a/python/resdata/summary/rd_smspec_node.py +++ b/python/resdata/summary/rd_smspec_node.py @@ -52,7 +52,7 @@ def __gt__(self, other): def __eq__(self, other): return self.cmp(other) == 0 - def __hash__(self, other): + def __hash__(self): return hash(self._gen_key1()) @property diff --git a/python/resdata/summary/rd_sum.py b/python/resdata/summary/rd_sum.py index ae10c0e66..5c3ef0578 100644 --- a/python/resdata/summary/rd_sum.py +++ b/python/resdata/summary/rd_sum.py @@ -390,8 
+390,10 @@ def add_t_step(self, report_step, sim_days): raise TypeError("Parameter report_step should be int, was %r" % report_step) try: float(sim_days) - except TypeError: - raise TypeError("Parameter sim_days should be float, was %r" % sim_days) + except TypeError as err: + raise TypeError( + "Parameter sim_days should be float, was %r" % sim_days + ) from err sim_seconds = sim_days * 24 * 60 * 60 tstep = self._add_tstep(report_step, sim_seconds).setParent(parent=self) @@ -407,6 +409,7 @@ def get_vector(self, key, report_only=False): warnings.warn( "The method get_vector() has been deprecated, use numpy_vector() instead", DeprecationWarning, + stacklevel=1, ) self.assertKeyValid(key) if report_only: @@ -458,6 +461,7 @@ def get_values(self, key, report_only=False): warnings.warn( "The method get_values() has been deprecated - use numpy_vector() instead.", DeprecationWarning, + stacklevel=1, ) if self.has_key(key): key_index = self._get_general_var_index(key) @@ -645,9 +649,10 @@ def _compile_headers_list( pass elif var_type == SummaryVarType.RD_SMSPEC_REGION_VAR: num = int(lst[1]) - elif var_type == SummaryVarType.RD_SMSPEC_GROUP_VAR: - wgname = lst[1] - elif var_type == SummaryVarType.RD_SMSPEC_WELL_VAR: + elif var_type in ( + SummaryVarType.RD_SMSPEC_GROUP_VAR, + SummaryVarType.RD_SMSPEC_WELL_VAR, + ): wgname = lst[1] elif var_type == SummaryVarType.RD_SMSPEC_SEGMENT_VAR: kw, wgname, num = lst @@ -799,6 +804,7 @@ def get_last_value(self, key): warnings.warn( "The function get_last_value() is deprecated, use last_value() instead", DeprecationWarning, + stacklevel=1, ) return self.last_value(key) @@ -852,10 +858,7 @@ def __len__(self): return self._data_length() def __contains__(self, key): - if self._has_key(key): - return True - else: - return False + return self._has_key(key) def assert_key_valid(self, key): if not key in self: @@ -873,6 +876,7 @@ def __getitem__(self, key): warnings.warn( "The method the [] operator will change behaviour in the future. 
It will then return a plain numpy vector. You are advised to change to use the numpy_vector() method right away", DeprecationWarning, + stacklevel=1, ) return self.get_vector(key) @@ -968,8 +972,7 @@ def time_range( if isinstance(start, datetime.date): start = datetime.datetime(start.year, start.month, start.day, 0, 0, 0) - if start < self.getDataStartTime(): - start = self.getDataStartTime() + start = max(start, self.getDataStartTime()) if end is None: end = self.getEndTime() @@ -977,8 +980,7 @@ def time_range( if isinstance(end, datetime.date): end = datetime.datetime(end.year, end.month, end.day, 0, 0, 0) - if end > self.getEndTime(): - end = self.getEndTime() + end = min(end, self.getEndTime()) if end < start: raise ValueError("Invalid time interval start after end") @@ -988,7 +990,7 @@ def time_range( range_start = start range_end = end - if not timeUnit == "d": + if timeUnit != "d": year1 = start.year year2 = end.year month1 = start.month @@ -1096,13 +1098,11 @@ def get_interp_vector(self, key, days_list=None, date_list=None): vector = numpy.zeros(len(days_list)) sim_length = self.sim_length sim_start = self.first_day - index = 0 - for days in days_list: + for index, days in enumerate(days_list): if (days >= sim_start) and (days <= sim_length): vector[index] = self._get_general_var_from_sim_days(days, key) else: raise ValueError("Invalid days value") - index += 1 elif date_list: start_time = self.data_start end_time = self.end_date @@ -1265,6 +1265,7 @@ def mpl_dates(self): warnings.warn( "The mpl_dates property has been deprecated - use numpy_dates instead", DeprecationWarning, + stacklevel=1, ) return self.get_mpl_dates(False) @@ -1281,6 +1282,7 @@ def get_mpl_dates(self, report_only=False): warnings.warn( "The get_mpl_dates( ) method has been deprecated - use numpy_dates instead", DeprecationWarning, + stacklevel=1, ) if report_only: return [date2num(dt) for dt in self.report_dates] @@ -1486,7 +1488,7 @@ def solve_dates(self, key, value, 
rates_clamp_lower=True): return [x.datetime() for x in self._solve_dates(key, value, rates_clamp_lower)] def solve_days(self, key, value, rates_clamp_lower=True): - """Will solve the equation vector[@key] == value. + r"""Will solve the equation vector[@key] == value. This method will solve find tha approximate simulation days where the vector @key is equal @value. The method will return diff --git a/python/resdata/summary/rd_sum_vector.py b/python/resdata/summary/rd_sum_vector.py index 8bdf0ba8b..362ab1217 100644 --- a/python/resdata/summary/rd_sum_vector.py +++ b/python/resdata/summary/rd_sum_vector.py @@ -1,5 +1,7 @@ from __future__ import print_function + import warnings + from .rd_sum_node import SummaryNode @@ -29,6 +31,7 @@ def __init__(self, parent, key, report_only=False): warnings.warn( "The report_only flag to the SummaryVector will be removed", DeprecationWarning, + stacklevel=1, ) self.__dates = parent.get_dates(report_only) @@ -98,6 +101,7 @@ def mpl_dates(self): warnings.warn( "The mpl_dates property has been deprecated - use numpy_dates instead", DeprecationWarning, + stacklevel=1, ) return self.parent.get_mpl_dates(self.report_only) diff --git a/python/resdata/util/test/extended_testcase.py b/python/resdata/util/test/extended_testcase.py index d1c1e1caf..a7656cc9a 100644 --- a/python/resdata/util/test/extended_testcase.py +++ b/python/resdata/util/test/extended_testcase.py @@ -130,8 +130,8 @@ def assertDirectoryDoesNotExist(self, path): self.fail("The directory: %s exists!" 
% path) def __filesAreEqual(self, first, second): - buffer1 = open(first, "rb").read() - buffer2 = open(second, "rb").read() + buffer1 = open(first, "rb").read() # noqa: SIM115 + buffer2 = open(second, "rb").read() # noqa: SIM115 return buffer1 == buffer2 @@ -198,7 +198,4 @@ def requireVersion(major, minor, micro="git"): required_version = Version(major, minor, micro) current_version = Version.currentVersion() - if required_version < current_version: - return True - else: - return False + return required_version < current_version diff --git a/python/resdata/util/test/mock/rd_sum_mock.py b/python/resdata/util/test/mock/rd_sum_mock.py index 0b9d13ad1..a5ee172d6 100644 --- a/python/resdata/util/test/mock/rd_sum_mock.py +++ b/python/resdata/util/test/mock/rd_sum_mock.py @@ -1,4 +1,5 @@ import datetime + from resdata.summary import Summary @@ -15,7 +16,7 @@ def createSummary( num_report_step=5, num_mini_step=10, dims=(20, 10, 5), - func_table={}, + func_table={}, # noqa: B006 restart_case=None, restart_step=-1, ): diff --git a/python/resdata/util/test/path_context.py b/python/resdata/util/test/path_context.py index d042f1d01..5d622265b 100644 --- a/python/resdata/util/test/path_context.py +++ b/python/resdata/util/test/path_context.py @@ -23,9 +23,8 @@ def __init__(self, path, store=False): break os.makedirs(path) - else: - if not self.store: - raise OSError("Entry %s already exists" % path) + elif not self.store: + raise OSError("Entry %s already exists" % path) os.chdir(path) def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/python/resdata/util/test/resdata_test_runner.py b/python/resdata/util/test/resdata_test_runner.py index d6e0c60e6..6684b9dc9 100644 --- a/python/resdata/util/test/resdata_test_runner.py +++ b/python/resdata/util/test/resdata_test_runner.py @@ -19,7 +19,7 @@ def findTestsInDirectory(path, recursive=True, pattern="test*.py"): loader = TestLoader() test_suite = loader.discover(path, pattern=pattern) - for root, dirnames, filenames in 
os.walk(path): + for root, dirnames, _filenames in os.walk(path): for directory in dirnames: test_suite.addTests( ResdataTestRunner.findTestsInDirectory( diff --git a/python/resdata/util/test/source_enumerator.py b/python/resdata/util/test/source_enumerator.py index 66087830d..aca52b04a 100644 --- a/python/resdata/util/test/source_enumerator.py +++ b/python/resdata/util/test/source_enumerator.py @@ -6,7 +6,7 @@ class SourceEnumerator(object): @classmethod def removeComments(cls, code_string): code_string = re.sub( - re.compile("/\*.*?\*/", re.DOTALL), "", code_string + re.compile(r"/\*.*?\*/", re.DOTALL), "", code_string ) # remove all occurance streamed comments (/*COMMENT */) from string code_string = re.sub( re.compile("//.*?\n"), "", code_string @@ -20,7 +20,7 @@ def findEnum(cls, enum_name, full_source_file_path): text = SourceEnumerator.removeComments(text) - enum_pattern = re.compile("typedef\s+enum\s+\{(.*?)\}\s*(\w+?);", re.DOTALL) + enum_pattern = re.compile(r"typedef\s+enum\s+\{(.*?)\}\s*(\w+?);", re.DOTALL) for enum in enum_pattern.findall(text): if enum[1] == enum_name: @@ -32,7 +32,7 @@ def findEnum(cls, enum_name, full_source_file_path): def findEnumerators(cls, enum_name, source_file): enum_text = SourceEnumerator.findEnum(enum_name, source_file) - enumerator_pattern = re.compile("(\w+?)\s*?=\s*?(\d+)") + enumerator_pattern = re.compile(r"(\w+?)\s*?=\s*?(\d+)") enumerators = [] for enumerator in enumerator_pattern.findall(enum_text): diff --git a/python/resdata/util/test/test_run.py b/python/resdata/util/test/test_run.py index 59fad8ffe..6b9955e7c 100644 --- a/python/resdata/util/test/test_run.py +++ b/python/resdata/util/test/test_run.py @@ -17,7 +17,7 @@ class TestRun(object): default_ert_version = "stable" default_path_prefix = None - def __init__(self, config_file, args=[], name=None): + def __init__(self, config_file, args=[], name=None): # noqa: B006 if os.path.exists(config_file) and os.path.isfile(config_file): self.parseArgs(args) 
self.__ert_cmd = TestRun.default_ert_cmd diff --git a/python/resdata/util/util/__init__.py b/python/resdata/util/util/__init__.py index a3b20cb43..0b03cb91c 100644 --- a/python/resdata/util/util/__init__.py +++ b/python/resdata/util/util/__init__.py @@ -72,7 +72,7 @@ def __user_warning(msg): def __dev_warning(msg): - warnings.warn(msg, DeprecationWarning) + warnings.warn(msg, DeprecationWarning, stacklevel=1) def __hard_warning(msg): diff --git a/python/resdata/util/util/ctime.py b/python/resdata/util/util/ctime.py index b22114919..655cd70eb 100644 --- a/python/resdata/util/util/ctime.py +++ b/python/resdata/util/util/ctime.py @@ -3,6 +3,7 @@ import time from cwrap import BaseCValue + from resdata import ResdataPrototype @@ -16,7 +17,7 @@ class CTime(BaseCValue): def __init__(self, value): if isinstance(value, int): - value = value + pass elif isinstance(value, CTime): value = value.value() elif isinstance(value, datetime.datetime): diff --git a/python/resdata/util/util/lookup_table.py b/python/resdata/util/util/lookup_table.py index a1c6a4e2c..45fcb4657 100644 --- a/python/resdata/util/util/lookup_table.py +++ b/python/resdata/util/util/lookup_table.py @@ -1,4 +1,5 @@ from cwrap import BaseCClass + from resdata import ResdataPrototype @@ -101,12 +102,11 @@ def interp(self, x): "Interpolate argument:%g is outside valid interval: [%g,%g]" % (x, self.getMinArg(), self.getMaxArg()) ) - elif x > self.getMaxArg(): - if not self.hasUpperLimit(): - raise ValueError( - "Interpolate argument:%g is outside valid interval: [%g,%g]" - % (x, self.getMinArg(), self.getMaxArg()) - ) + elif x > self.getMaxArg() and not self.hasUpperLimit(): + raise ValueError( + "Interpolate argument:%g is outside valid interval: [%g,%g]" + % (x, self.getMinArg(), self.getMaxArg()) + ) return self._interp(x) diff --git a/python/resdata/util/util/thread_pool.py b/python/resdata/util/util/thread_pool.py index 52b94998c..bbfd3f6bc 100644 --- a/python/resdata/util/util/thread_pool.py +++ 
b/python/resdata/util/util/thread_pool.py @@ -1,7 +1,7 @@ import multiprocessing -from threading import Thread import time import traceback +from threading import Thread class Task(Thread): @@ -82,11 +82,7 @@ def taskCount(self): return len(self.__task_list) def __allTasksFinished(self): - for task in self.__task_list: - if not task.isDone(): - return False - - return True + return all(task.isDone() for task in self.__task_list) def runningCount(self): count = 0 @@ -143,8 +139,4 @@ def join(self): ) def hasFailedTasks(self): - for task in self.__task_list: - if task.hasFailed(): - return True - - return False + return any(task.hasFailed() for task in self.__task_list) diff --git a/python/resdata/util/util/time_vector.py b/python/resdata/util/util/time_vector.py index fbd49ee2e..0839a7596 100644 --- a/python/resdata/util/util/time_vector.py +++ b/python/resdata/util/util/time_vector.py @@ -120,16 +120,16 @@ def __init__(self, default_value=None, initial_size=0): else: try: default = CTime(default_value) - except: + except Exception as err: raise ValueError( "default value invalid - must be type ctime() or date/datetime" - ) + ) from err super(TimeVector, self).__init__(default, initial_size) @classmethod def parseTimeUnit(cls, deltaString): - deltaRegexp = re.compile("(?P\d*)(?P[dmy])", re.IGNORECASE) + deltaRegexp = re.compile(r"(?P\d*)(?P[dmy])", re.IGNORECASE) matchObj = deltaRegexp.match(deltaString) if matchObj: try: diff --git a/python/resdata/util/util/vector_template.py b/python/resdata/util/util/vector_template.py index 7b2e8a380..c950d25cd 100644 --- a/python/resdata/util/util/vector_template.py +++ b/python/resdata/util/util/vector_template.py @@ -61,10 +61,7 @@ def __bool__(self): """ Will evaluate to False for empty vector. 
""" - if len(self) == 0: - return False - else: - return True + return len(self) == 0 def __nonzero__(self): return self.__bool__() @@ -266,13 +263,12 @@ def __IADD(self, delta, add): "Incompatible sizes for add self:%d other:%d" % (len(self), len(delta)) ) + elif isinstance(delta, (int, float)): + if not add: + delta *= -1 + self._shift(delta) else: - if isinstance(delta, int) or isinstance(delta, float): - if not add: - delta *= -1 - self._shift(delta) - else: - raise TypeError("delta has wrong type:%s " % type(delta)) + raise TypeError("delta has wrong type:%s " % type(delta)) return self @@ -338,11 +334,10 @@ def __imul__(self, factor): "Incompatible sizes for mul self:%d other:%d" % (len(self), len(factor)) ) + elif isinstance(factor, (int, float)): + self._scale(factor) else: - if isinstance(factor, int) or isinstance(factor, float): - self._scale(factor) - else: - raise TypeError("factor has wrong type:%s " % type(factor)) + raise TypeError("factor has wrong type:%s " % type(factor)) return self @@ -355,7 +350,7 @@ def __rmul__(self, factor): return self.__mul__(factor) def __div__(self, divisor): - if isinstance(divisor, int) or isinstance(divisor, float): + if isinstance(divisor, (int, float)): copy = self._alloc_copy() copy._div(divisor) return copy @@ -396,11 +391,10 @@ def assign(self, value): if type(self) == type(value): # This is a copy operation self._memcpy(value) + elif isinstance(value, (int, float)): + self._assign(value) else: - if isinstance(value, int) or isinstance(value, float): - self._assign(value) - else: - raise TypeError("Value has wrong type") + raise TypeError("Value has wrong type") def __len__(self): """ diff --git a/python/resdata/util/util/version.py b/python/resdata/util/util/version.py index cd32d6e71..5482ee93d 100644 --- a/python/resdata/util/util/version.py +++ b/python/resdata/util/util/version.py @@ -91,11 +91,10 @@ def getBuildTime(self): def getGitCommit(self, short=False): if self.git_commit is None: return "???????" 
+ elif short: + return self.git_commit[0:8] else: - if short: - return self.git_commit[0:8] - else: - return self.git_commit + return self.git_commit class ResdataVersion(Version): diff --git a/python/resdata/well/well_state.py b/python/resdata/well/well_state.py index f0962bd3d..cb10dfffa 100644 --- a/python/resdata/well/well_state.py +++ b/python/resdata/well/well_state.py @@ -163,10 +163,7 @@ def hasSegmentData(self): def __repr__(self): name = self.name() - if name: - name = "%s" % name - else: - name = "[no name]" + name = "%s" % name if name else "[no name]" msw = " (multi segment)" if self.isMultiSegmentWell() else "" wn = str(self.wellNumber()) type_ = self.wellType() diff --git a/python/tests/__init__.py b/python/tests/__init__.py index 709e3de3f..f83881a4e 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -61,14 +61,13 @@ def equinor_test(): """ def decorator(test_item): - if not isinstance(test_item, type): - if not ResdataTest.EQUINOR_DATA: + if not isinstance(test_item, type) and not ResdataTest.EQUINOR_DATA: - @functools.wraps(test_item) - def skip_wrapper(*args, **kwargs): - raise SkipTest("Missing Equinor testdata") + @wraps(test_item) + def skip_wrapper(*args, **kwargs): + raise SkipTest("Missing Equinor testdata") - test_item = skip_wrapper + test_item = skip_wrapper if not ResdataTest.EQUINOR_DATA: test_item.__unittest_skip__ = True diff --git a/python/tests/geometry_tests/test_surface.py b/python/tests/geometry_tests/test_surface.py index e23691a07..8d6fa0e58 100644 --- a/python/tests/geometry_tests/test_surface.py +++ b/python/tests/geometry_tests/test_surface.py @@ -55,8 +55,8 @@ def test_create_new(self): self.assertNotEqual(s, small) idx = 0 - for i in range(nx): - for j in range(ny): + for _i in range(nx): + for _j in range(ny): s[idx] = small[idx] idx += 1 self.assertEqual(s, small) @@ -164,7 +164,7 @@ def test_ops(self): def test_ops2(self): s0 = Surface(self.surface_small) surface_list = [] - for i in range(10): + 
for _i in range(10): s = s0.copy() for j in range(len(s)): s[j] = random.random() diff --git a/python/tests/rd_tests/test_fault_blocks.py b/python/tests/rd_tests/test_fault_blocks.py index 0fc5734c9..221a98bd6 100644 --- a/python/tests/rd_tests/test_fault_blocks.py +++ b/python/tests/rd_tests/test_fault_blocks.py @@ -359,7 +359,6 @@ def test_add_polyline_barrier2(self): ((4, 0), (4, 1)), ((6, 0), (6, 1)), ((8, 0), (8, 1)), - # ((8, 1), (9, 1)), ((8, 3), (9, 3)), ((8, 5), (9, 5)), diff --git a/python/tests/rd_tests/test_faults.py b/python/tests/rd_tests/test_faults.py index abf5dab6a..5135eee59 100644 --- a/python/tests/rd_tests/test_faults.py +++ b/python/tests/rd_tests/test_faults.py @@ -94,7 +94,7 @@ def test_empty_fault(self): f = Fault(self.grid, "NAME") self.assertEqual("NAME", f.getName()) - with self.assertRaises(Exception): + with self.assertRaises(KeyError): g = f["Key"] def test_empty_faultLine(self): @@ -409,7 +409,7 @@ def test_iter(self): faults = FaultCollection(self.grid, self.faults1, self.faults2) self.assertEqual(7, len(faults)) c = 0 - for f in faults: + for _f in faults: c += 1 self.assertEqual(c, len(faults)) diff --git a/python/tests/rd_tests/test_fortio.py b/python/tests/rd_tests/test_fortio.py index e2161a0df..8015faf78 100755 --- a/python/tests/rd_tests/test_fortio.py +++ b/python/tests/rd_tests/test_fortio.py @@ -57,9 +57,8 @@ def test_truncate(self): kw2.fwrite(f) # Truncate file in read mode; should fail hard. 
- with openFortIO("file") as f: - with self.assertRaises(IOError): - f.truncate() + with openFortIO("file") as f, self.assertRaises(IOError): + f.truncate() with openFortIO("file", mode=FortIO.READ_AND_WRITE_MODE) as f: f.seek(pos1) diff --git a/python/tests/rd_tests/test_geertsma.py b/python/tests/rd_tests/test_geertsma.py index 712541a65..de8864998 100644 --- a/python/tests/rd_tests/test_geertsma.py +++ b/python/tests/rd_tests/test_geertsma.py @@ -135,34 +135,6 @@ def test_geertsma_kernel_seabed(): ) np.testing.assert_almost_equal(dz, 5.819790154474284e-08) - @staticmethod - def test_geertsma_kernel_seabed(): - grid = Grid.createRectangular(dims=(1, 1, 1), dV=(50, 50, 50)) - with TestAreaContext("Subsidence"): - p1 = [1] - create_restart(grid, "TEST", p1) - create_init(grid, "TEST") - - init = ResdataFile("TEST.INIT") - restart_file = ResdataFile("TEST.UNRST") - - restart_view1 = restart_file.restartView(sim_time=datetime.date(2000, 1, 1)) - - subsidence = ResdataSubsidence(grid, init) - subsidence.add_survey_PRESSURE("S1", restart_view1) - - youngs_modulus = 5e8 - poisson_ratio = 0.3 - seabed = 300 - above = 100 - topres = 2000 - receiver = (1000, 1000, topres - seabed - above) - - dz = subsidence.evalGeertsma( - "S1", None, receiver, youngs_modulus, poisson_ratio, seabed - ) - np.testing.assert_almost_equal(dz, 5.819790154474284e-08) - def test_geertsma_rporv_kernel_2_source_points_2_vintages(self): grid = Grid.createRectangular(dims=(2, 1, 1), dV=(100, 100, 100)) diff --git a/python/tests/rd_tests/test_grid.py b/python/tests/rd_tests/test_grid.py index 4d8e734c1..09529f5bb 100644 --- a/python/tests/rd_tests/test_grid.py +++ b/python/tests/rd_tests/test_grid.py @@ -306,7 +306,7 @@ def test_all_iters(self): self.assertTrue(c.active) self.assertEqual(cnt, 4160) - cnt = len([c for c in grid.cells()]) + cnt = len(list(grid.cells())) self.assertEqual(cnt, len(grid)) def test_repr_and_name(self): diff --git a/python/tests/rd_tests/test_grid_equinor.py 
b/python/tests/rd_tests/test_grid_equinor.py index d433fa8a3..a8ecc05c2 100755 --- a/python/tests/rd_tests/test_grid_equinor.py +++ b/python/tests/rd_tests/test_grid_equinor.py @@ -171,16 +171,15 @@ def test_grdecl_load(self): f2.write("SPECGRID\n") f2.write(" 10 10 10 'F' /\n") - with openResdataFile("G.EGRID") as f: - with copen("grid.grdecl", "a") as f2: - coord_kw = f["COORD"][0] - coord_kw.write_grdecl(f2) + with openResdataFile("G.EGRID") as f, copen("grid.grdecl", "a") as f2: + coord_kw = f["COORD"][0] + coord_kw.write_grdecl(f2) - zcorn_kw = f["ZCORN"][0] - zcorn_kw.write_grdecl(f2) + zcorn_kw = f["ZCORN"][0] + zcorn_kw.write_grdecl(f2) - actnum_kw = f["ACTNUM"][0] - actnum_kw.write_grdecl(f2) + actnum_kw = f["ACTNUM"][0] + actnum_kw.write_grdecl(f2) g2 = Grid.loadFromGrdecl("grid.grdecl") self.assertTrue(g1.equal(g2)) @@ -252,7 +251,7 @@ def test_boundingBox(self): def test_num_active_large_memory(self): case = self.createTestPath("Equinor/ECLIPSE/Gurbat/ECLIPSE") vecList = [] - for i in range(12500): + for _i in range(12500): vec = DoubleVector() vec[81920] = 0 vecList.append(vec) diff --git a/python/tests/rd_tests/test_grid_generator.py b/python/tests/rd_tests/test_grid_generator.py index 4c261b643..808f4fc45 100644 --- a/python/tests/rd_tests/test_grid_generator.py +++ b/python/tests/rd_tests/test_grid_generator.py @@ -64,15 +64,15 @@ def test_extract_grid_decomposition_change(self): coord = list(GridGen.create_coord(dims, (1, 1, 1))) ijk_bounds = generate_ijk_bounds(dims) - for ijk_bounds in ijk_bounds: - if decomposition_preserving(ijk_bounds): - GridGen.extract_subgrid_data(dims, coord, zcorn, ijk_bounds) + for ijk_bound in ijk_bounds: + if decomposition_preserving(ijk_bound): + GridGen.extract_subgrid_data(dims, coord, zcorn, ijk_bound) else: with self.assertRaises(ValueError): - GridGen.extract_subgrid_data(dims, coord, zcorn, ijk_bounds) + GridGen.extract_subgrid_data(dims, coord, zcorn, ijk_bound) GridGen.extract_subgrid_data( - dims, coord, 
zcorn, ijk_bounds, decomposition_change=True + dims, coord, zcorn, ijk_bound, decomposition_change=True ) def test_extract_grid_invalid_bounds(self): diff --git a/python/tests/rd_tests/test_npv.py b/python/tests/rd_tests/test_npv.py index d69fd6026..ee4f23932 100644 --- a/python/tests/rd_tests/test_npv.py +++ b/python/tests/rd_tests/test_npv.py @@ -39,7 +39,7 @@ def setUp(self): self.case = self.createTestPath(case) def test_create(self): - with self.assertRaises(Exception): + with self.assertRaises(OSError): npv = ResdataNPV("/does/not/exist") npv = ResdataNPV(self.case) diff --git a/python/tests/rd_tests/test_rd_cmp.py b/python/tests/rd_tests/test_rd_cmp.py index 13f1d2029..152cdb762 100644 --- a/python/tests/rd_tests/test_rd_cmp.py +++ b/python/tests/rd_tests/test_rd_cmp.py @@ -43,7 +43,7 @@ def test_wells(self): rd_cmp = ResdataCmp(self.root1, self.root1) wells = rd_cmp.testWells() - well_set = set(["OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "WI_1", "WI_2", "WI_3"]) + well_set = {"OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "WI_1", "WI_2", "WI_3"} self.assertEqual(len(wells), len(well_set)) for well in wells: self.assertTrue(well in well_set) diff --git a/python/tests/rd_tests/test_rd_kw.py b/python/tests/rd_tests/test_rd_kw.py index 4ff7b17e8..b8f86f2e1 100644 --- a/python/tests/rd_tests/test_rd_kw.py +++ b/python/tests/rd_tests/test_rd_kw.py @@ -68,10 +68,9 @@ def kw_test(self, data_type, data, fmt): kw.fprintf_data(file1, fmt) file1.close() - file2 = open(name2, "w") - for d in data: - file2.write(fmt % d) - file2.close() + with open(name2, "w") as file2: + for d in data: + file2.write(fmt % d) self.assertFilesAreEqual(name1, name2) self.assertEqual(kw.data_type, data_type) @@ -141,10 +140,8 @@ def test_kw_write(self): data = [random.random() for i in range(10000)] kw = ResdataKW("TEST", len(data), ResDataType.RD_DOUBLE) - i = 0 - for d in data: + for i, d in enumerate(data): kw[i] = d - i += 1 pfx = "ResdataKW(" self.assertEqual(pfx, repr(kw)[: len(pfx)]) @@ 
-186,12 +183,12 @@ def test_fprintf_data(self):
             kw.fprintf_data(fileH)
             fileH.close()

-            fileH = open("test", "r")
-            data = []
-            for line in fileH.readlines():
-                tmp = line.split()
-                for elm in tmp:
-                    data.append(int(elm))
+            with open("test", "r") as fileH:
+                data = []
+                for line in fileH.readlines():
+                    tmp = line.split()
+                    for elm in tmp:
+                        data.append(int(elm))

             for v1, v2 in zip(data, kw):
                 self.assertEqual(v1, v2)
diff --git a/python/tests/rd_tests/test_rd_sum.py b/python/tests/rd_tests/test_rd_sum.py
index e0dac8b7d..26b537e77 100644
--- a/python/tests/rd_tests/test_rd_sum.py
+++ b/python/tests/rd_tests/test_rd_sum.py
@@ -7,6 +7,7 @@
 from cwrap import open as copen
 from hypothesis import assume, given
 from pandas.testing import assert_frame_equal
+
 from resdata.resfile import FortIO, ResdataKW, openFortIO, openResdataFile
 from resdata.summary import Summary, SummaryKeyWordVector
 from resdata.util.test import TestAreaContext
@@ -135,9 +136,8 @@ def test_missing_unsmry_keyword(self):
         with openFortIO("ECLIPSE.UNSMRY", mode=FortIO.WRITE_MODE) as f:
             c = 0
             for kw in kw_list:
-                if kw.getName() == "PARAMS":
-                    if c % 5 == 0:
-                        continue
+                if kw.getName() == "PARAMS" and c % 5 == 0:
+                    continue
                 c += 1
                 kw.fwrite(f)

diff --git a/python/tests/rd_tests/test_rd_type.py b/python/tests/rd_tests/test_rd_type.py
index adad89d9a..910f14c11 100644
--- a/python/tests/rd_tests/test_rd_type.py
+++ b/python/tests/rd_tests/test_rd_type.py
@@ -120,7 +120,7 @@ def test_equals(self):
         self.assertTrue(a.is_equal(b))
         self.assertEqual(a, b)

-        for otype, osize in set(test_base) - set([(rd_type, elem_size)]):
+        for otype, osize in set(test_base) - {(rd_type, elem_size)}:
             self.assertFalse(a.is_equal(ResDataType(otype, osize)))
             self.assertNotEqual(a, ResDataType(otype, osize))

@@ -132,10 +132,10 @@ def test_hash(self):
             all_types.add(ResDataType(rd_type, elem_size))
             self.assertEqual(index + 1, len(all_types))

-        for index, (rd_type, elem_size) in enumerate(test_base):
+        for _index, (rd_type, elem_size) 
in enumerate(test_base): all_types.add(ResDataType(rd_type, elem_size)) - for index, rd_type in enumerate(get_const_size_types()): + for _index, rd_type in enumerate(get_const_size_types()): all_types.add(ResDataType(rd_type)) self.assertEqual(len(test_base), len(all_types)) diff --git a/python/tests/rd_tests/test_rd_util.py b/python/tests/rd_tests/test_rd_util.py index 7343abb10..915624546 100644 --- a/python/tests/rd_tests/test_rd_util.py +++ b/python/tests/rd_tests/test_rd_util.py @@ -41,7 +41,7 @@ def test_get_start_date_reads_from_start_kw_in_data_file(tmp_path): data_file = tmp_path / "dfile" data_file.write_text( dedent( - f"""\ + """\ START 4 Apr 2024 / """ diff --git a/python/tests/rd_tests/test_region_equinor.py b/python/tests/rd_tests/test_region_equinor.py index dc84dd95c..2f58afd38 100755 --- a/python/tests/rd_tests/test_region_equinor.py +++ b/python/tests/rd_tests/test_region_equinor.py @@ -126,7 +126,7 @@ def test_slice(self): OK = False self.assertTrue(OK) - self.assertTrue(2 * 3 * 6 == len(reg.getGlobalList())) + self.assertTrue(len(reg.getGlobalList()) == 2 * 3 * 6) def test_index_list(self): reg = ResdataRegion(self.grid, False) diff --git a/python/tests/rd_tests/test_sum.py b/python/tests/rd_tests/test_sum.py index 198f4627e..9fdc0239f 100644 --- a/python/tests/rd_tests/test_sum.py +++ b/python/tests/rd_tests/test_sum.py @@ -4,22 +4,10 @@ import os.path import shutil import stat -import datetime +from contextlib import contextmanager import cwrap - - -def assert_frame_equal(a, b): - if not a.equals(b): - raise AssertionError("Expected dataframes to be equal") - - -try: - from pandas.testing import assert_frame_equal -except ImportError: - pass - -from contextlib import contextmanager +from pandas.testing import assert_frame_equal from resdata import ResDataType, UnitSystem from resdata.resfile import FortIO, ResdataFile, ResdataKW, openFortIO @@ -204,7 +192,7 @@ def test_csv_export(self): with TestAreaContext("resdata/csv"): 
case.exportCSV("file.csv", sep=sep)
             self.assertTrue(os.path.isfile("file.csv"))
-            input_file = csv.DictReader(open("file.csv"), delimiter=sep)
+            input_file = csv.DictReader(open("file.csv"), delimiter=sep)  # noqa: SIM115
             for row in input_file:
                 self.assertIn("DAYS", row)
                 self.assertIn("DATE", row)
@@ -218,7 +206,7 @@ def test_csv_export(self):
         with TestAreaContext("resdata/csv"):
             case.exportCSV("file.csv", keys=["FOPT"], sep=sep)
             self.assertTrue(os.path.isfile("file.csv"))
-            input_file = csv.DictReader(open("file.csv"), delimiter=sep)
+            input_file = csv.DictReader(open("file.csv"), delimiter=sep)  # noqa: SIM115
             for row in input_file:
                 self.assertIn("DAYS", row)
                 self.assertIn("DATE", row)
@@ -659,8 +647,8 @@ def test_load_case(self):
         self.assertFloatEqual(case.sim_length, 545.0)

         fopr = case.numpy_vector("FOPR")
-        for time_index, value in enumerate(fopr):
-            self.assertEqual(fopr[time_index], value)
+        for time_index, value in enumerate(fopr):
+            self.assertEqual(fopr[time_index], value)

 def test_load_case_lazy_and_eager(self):
         path = os.path.join(self.TESTDATA_ROOT, "local/ECLIPSE/cp_simple3/SHORT.UNSMRY")
@@ -718,7 +706,7 @@ def test_resample_extrapolate(self):
             upper_extrapolation=True,
         )

-        for key in rd_sum.keys():
+        for key in rd_sum:
             self.assertIn(key, resampled)

         self.assertEqual(
@@ -746,9 +734,7 @@ def test_resample_extrapolate(self):
         key_rate = "FOPR"
         for time_index, t in enumerate(time_points):
-            if t < rd_sum.get_data_start_time():
-                self.assertFloatEqual(resampled.iget(key_rate, time_index), 0)
-            elif t > rd_sum.get_end_time():
+            if t < rd_sum.get_data_start_time() or t > rd_sum.get_end_time():
                 self.assertFloatEqual(resampled.iget(key_rate, time_index), 0)
             else:
                 self.assertFloatEqual(
diff --git a/python/tests/rd_tests/test_sum_equinor.py b/python/tests/rd_tests/test_sum_equinor.py
index 0f46400dd..17d05d2d3 100755
--- a/python/tests/rd_tests/test_sum_equinor.py
+++ b/python/tests/rd_tests/test_sum_equinor.py
@@ -109,17 +109,17 @@ def test_wells(self):
         wells = 
self.rd_sum.wells() wells.sort() self.assertListEqual( - [well for well in wells], + list(wells), ["OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "WI_1", "WI_2", "WI_3"], ) wells = self.rd_sum.wells(pattern="*_3") wells.sort() - self.assertListEqual([well for well in wells], ["OP_3", "WI_3"]) + self.assertListEqual(list(wells), ["OP_3", "WI_3"]) groups = self.rd_sum.groups() groups.sort() - self.assertListEqual([group for group in groups], ["GMWIN", "OP", "WI"]) + self.assertListEqual(list(groups), ["GMWIN", "OP", "WI"]) def test_last(self): last = self.rd_sum.get_last("FOPT") @@ -283,7 +283,7 @@ def test_stringlist_reference(self): sum = Summary(self.case) wells = sum.wells() self.assertListEqual( - [well for well in wells], + list(wells), ["OP_1", "OP_2", "OP_3", "OP_4", "OP_5", "WI_1", "WI_2", "WI_3"], ) self.assertIsInstance(wells, StringList) @@ -475,10 +475,7 @@ def test_ix_case(self): self.assertTrue( "HWELL_PROD" - in [ - intersect_summary.smspec_node(key).wgname - for key in intersect_summary.keys() - ] + in [intersect_summary.smspec_node(key).wgname for key in intersect_summary] ) eclipse_summary = Summary( @@ -538,7 +535,7 @@ def test_resample(self): time_points.append(CTime(end_time)) resampled = self.rd_sum.resample("OUTPUT_CASE", time_points) - for key in self.rd_sum.keys(): + for key in self.rd_sum: self.assertIn(key, resampled) self.assertEqual( diff --git a/python/tests/test_bin.py b/python/tests/test_bin.py index 40fcf5f29..976b75b20 100644 --- a/python/tests/test_bin.py +++ b/python/tests/test_bin.py @@ -30,6 +30,6 @@ ], ) def test_exec(name: str, returncode: int, stderr: str) -> None: - status = subprocess.run([name], stderr=subprocess.PIPE) + status = subprocess.run([name], stderr=subprocess.PIPE, check=False) assert status.returncode == returncode assert stderr in status.stderr diff --git a/python/tests/util_tests/test_path_context.py b/python/tests/util_tests/test_path_context.py index acd181ddc..0723be0cb 100644 --- 
a/python/tests/util_tests/test_path_context.py +++ b/python/tests/util_tests/test_path_context.py @@ -1,4 +1,5 @@ import os + from resdata.util.test import PathContext, TestAreaContext from tests import ResdataTest @@ -8,16 +9,14 @@ def test_error(self): with TestAreaContext("pathcontext"): # Test failure on creating PathContext with an existing path os.makedirs("path/1") - with self.assertRaises(OSError): - with PathContext("path/1"): - pass + with self.assertRaises(OSError), PathContext("path/1"): + pass # Test failure on creating PathContext with an existing file with open("path/1/file", "w") as f: f.write("xx") - with self.assertRaises(OSError): - with PathContext("path/1/file"): - pass + with self.assertRaises(OSError), PathContext("path/1/file"): + pass def test_chdir(self): with PathContext("/tmp/pc"): @@ -27,9 +26,8 @@ def test_cleanup(self): with TestAreaContext("pathcontext"): os.makedirs("path/1") - with PathContext("path/1/next/2/level"): - with open("../../file", "w") as f: - f.write("Crap") + with PathContext("path/1/next/2/level"), open("../../file", "w") as f: + f.write("Crap") self.assertTrue(os.path.isdir("path/1")) self.assertTrue(os.path.isdir("path/1/next")) diff --git a/python/tests/util_tests/test_string_list.py b/python/tests/util_tests/test_string_list.py index a2aa344df..2462f1536 100644 --- a/python/tests/util_tests/test_string_list.py +++ b/python/tests/util_tests/test_string_list.py @@ -73,7 +73,7 @@ def test_last(self): s.pop() s.pop() s.pop() - s.last + _ = s.last def test_in_and_not_in(self): s = StringList(["A", "list", "of", "strings"]) diff --git a/python/tests/util_tests/test_thread_pool.py b/python/tests/util_tests/test_thread_pool.py index 8e3b2aa1a..eb99a9bec 100644 --- a/python/tests/util_tests/test_thread_pool.py +++ b/python/tests/util_tests/test_thread_pool.py @@ -61,7 +61,7 @@ def test_pool_unbound_fail(self): def test_fill_pool(self): pool = ThreadPool(4) - for index in range(10): + for _index in range(10): 
pool.addTask(self.sleepTask, 2) pool.nonBlockingStart() diff --git a/python/tests/util_tests/test_vectors.py b/python/tests/util_tests/test_vectors.py index ea5e73faa..9d1e9d9d2 100755 --- a/python/tests/util_tests/test_vectors.py +++ b/python/tests/util_tests/test_vectors.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import copy import datetime + import six try: @@ -523,16 +524,16 @@ def range_test(self, v, a, b, d): r = range(a, b, d) self.assertEqual(len(v), len(r)) - for a, b in zip(v, r): - self.assertEqual(a, b) + for val1, val2 in zip(v, r): + self.assertEqual(val1, val2) def create_range_test(self, v, a, b, d): v = IntVector.createRange(a, b, d) r = range(a, b, d) self.assertEqual(len(v), len(r)) - for a, b in zip(v, r): - self.assertEqual(a, b) + for val1, val2 in zip(v, r): + self.assertEqual(val1, val2) def test_range(self): v = IntVector() diff --git a/python/tests/well_tests/test_rd_well.py b/python/tests/well_tests/test_rd_well.py index 8bb549e17..e37ba73cd 100644 --- a/python/tests/well_tests/test_rd_well.py +++ b/python/tests/well_tests/test_rd_well.py @@ -376,7 +376,7 @@ def test_well_info(self): self.assertNotEqual(well_info[0], well_info[1]) - well_time_lines = [wtl for wtl in well_info] + well_time_lines = list(well_info) self.assertEqual(len(well_time_lines), len(well_info)) with self.assertRaises(IndexError): diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index c62373ce9..000000000 --- a/setup.cfg +++ /dev/null @@ -1,5 +0,0 @@ -[aliases] -test = pytest - -[flake8] -max-line-length = 88