Skip to content

Commit

Permalink
Mypy type safety: round 2
Browse files Browse the repository at this point in the history
  • Loading branch information
TheBB committed Feb 12, 2024
1 parent 253d641 commit 970268e
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 196 deletions.
10 changes: 5 additions & 5 deletions splipy/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from bisect import bisect_left, bisect_right
from typing import Optional, Any, Sequence, Union, overload, cast
from typing import Optional, Any, Sequence, Union, overload, cast, Callable

import numpy as np
import scipy.sparse.linalg as splinalg
Expand Down Expand Up @@ -449,7 +449,7 @@ def rebuild(self, p: int, n: int) -> Curve:
# return new resampled curve
return Curve(basis, controlpoints)

def error(self, target: Curve) -> tuple[list[float], float]:
def error(self, target: Callable) -> tuple[list[float], float]:
"""Computes the L2 (squared and per knot span) and max error between
this curve and a target curve
Expand Down Expand Up @@ -480,11 +480,11 @@ def arclength_circle(t):
"""
knots = self.knots(0)
(x,w) = np.polynomial.legendre.leggauss(self.order(0)+1)
err2 = []
err2 = []
err_inf = 0.0
for t0,t1 in zip(knots[:-1], knots[1:]): # for all knot spans
tg = (x+1)/2*(t1-t0)+t0 # evaluation points
wg = w /2*(t1-t0) # integration weights
tg = (x+1)/2 * (t1-t0) + t0 # evaluation points
wg = w/2 * (t1-t0) # integration weights
error = self(tg) - target(tg) # [x-xh, y-yh, z-zh]
error = np.sum(error**2, axis=1) # |x-xh|^2
err2.append(np.dot(error, wg)) # integrate over domain
Expand Down
Loading

0 comments on commit 970268e

Please sign in to comment.