Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add spline model #321

Draft
wants to merge 25 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
40daec8
add spline model
ndem0 Aug 1, 2024
d580710
correction of orders and knot numbers
AleDinve Aug 20, 2024
1210c19
Change default reduction in SystemEquation (#317)
luAndre00 Aug 1, 2024
98d0fef
Add layer to perform RBF interpolation in reduced order modeling (#315)
annaivagnes Aug 12, 2024
536c5f9
Fix #316 (#324)
dario-coscia Aug 12, 2024
55efcfe
* Adding a test for all PINN solvers to assert that the metrics are c…
dario-coscia Aug 6, 2024
fafbda0
* Fixing mean tracked loss
dario-coscia Aug 6, 2024
c9f123d
Format Python code with psf/black push (#325)
github-actions[bot] Aug 12, 2024
16b8a39
starting doc
dario-coscia Jul 23, 2024
23f29d2
Added Contributing, Cite PINA, License to the navigation bar
Jul 25, 2024
6099e38
Logo color and font, adding PINA twitter
Jul 26, 2024
2c0277a
minor changes, version edited
Jul 27, 2024
e7f6e2c
deleting old logo
Aug 2, 2024
888a93a
deleted sphinx_rtd_theme from the extensions
Aug 2, 2024
d521a33
Gmail icon added, PINA team page updated
Aug 19, 2024
5ce6c03
add OrthogonalBlock to make input orthonormal
annaivagnes Aug 16, 2024
eb07150
:art: Format Python code with psf/black
ndem0 Aug 26, 2024
ace6565
Update orthogonal.py
ndem0 Aug 26, 2024
4205cb9
:art: Format Python code with psf/black
ndem0 Aug 26, 2024
033b0b7
Backpropagation and fix test for OrthogonalBlock
dario-coscia Sep 3, 2024
1877a45
:art: Format Python code with psf/black (#333)
github-actions[bot] Sep 3, 2024
6ca5c83
fixing documentation issues
dario-coscia Sep 3, 2024
38753c5
update doc (#335)
dario-coscia Sep 3, 2024
270d127
Version 0.1.2 (#336)
dario-coscia Sep 3, 2024
cd7d357
solving splines bugs
AleDinve Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
solving splines bugs
  • Loading branch information
AleDinve committed Sep 5, 2024
commit cd7d35707f2dd50705c96cf1b8808fbb63f1df6b
33 changes: 15 additions & 18 deletions pina/model/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,26 @@

class Spline(torch.nn.Module):

def __init__(self, order, knots=None, control_points=None) -> None:
def __init__(self, order=4, knots=None, control_points=None) -> None:
"""
Spline model.

:param int order: the order of the spline.

"""
super().__init__()

check_consistency(order, int)
if order < 0:
raise ValueError("Spline order cannot be negative.")
if knots is None and control_points is None:
raise ValueError("Knots and control points cannot be both None.")

self.order = order
self.k = order-1

if knots is not None:
self.knots = knots
n = len(knots) - order - 1
n = len(knots) - order
self.control_points = torch.nn.Parameter(
torch.zeros(n), requires_grad=True)

Expand All @@ -36,13 +36,12 @@ def __init__(self, order, knots=None, control_points=None) -> None:
'type': 'auto',
'min': 0,
'max': 1,
'n': n + order + 1}
'n': n}

else:
self.knots = knots
self.control_points = control_points

print(self.knots)

if self.knots.ndim != 1:
raise ValueError("Knot vector must be one-dimensional.")
Expand Down Expand Up @@ -93,7 +92,7 @@ def __init__(self, order, knots=None, control_points=None) -> None:
# plt.show()

@staticmethod
def B(x, k, i, t, order):
def B(x, k, i, t):
'''
x: points to be evaluated
k: spline degree
Expand All @@ -102,21 +101,21 @@ def B(x, k, i, t, order):
'''
if k == 0:
a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0)
if i == len(t) - order - 1:
a = torch.where(x == t[-1], 1.0, a)
if i == len(t) - k - 1:
a = torch.where(x == t[-1], 1.0, a)
a.requires_grad_(True)
return a


if t[i+k] == t[i]:
c1 = torch.tensor([0.0], requires_grad=True)
else:
c1 = (x - t[i])/(t[i+k] - t[i]) * Spline.B(x, k-1, i, t, order=order)
c1 = (x - t[i])/(t[i+k] - t[i]) * Spline.B(x, k-1, i, t)

if t[i+k+1] == t[i+1]:
c2 = torch.tensor([0.0], requires_grad=True)
else:
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * Spline.B(x, k-1, i+1, t, order=order)
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * Spline.B(x, k-1, i+1, t)

return c1 + c2

Expand All @@ -134,7 +133,6 @@ def control_points(self, value):
dim = value.get('dim', 1)
value = torch.zeros(n, dim)

print(value)
if not isinstance(value, torch.Tensor):
raise ValueError('Invalid value for control_points')

Expand All @@ -154,14 +152,13 @@ def knots(self, value):
n = value.get('n', 10)

if type_ == 'uniform':
value = torch.linspace(min_, max_, n + k + 1)
value = torch.linspace(min_, max_, n + self.k + 1)
elif type_ == 'auto':
k = self.order - 1
value = torch.concatenate(
(
torch.ones(k)*min_,
torch.linspace(min_, max_, n - k +1),
torch.ones(k)*max_,
torch.ones(self.k)*min_,
torch.linspace(min_, max_, n - self.k +1),
torch.ones(self.k)*max_,
# [self.max] * (k-1)
)
)
Expand All @@ -173,7 +170,7 @@ def knots(self, value):

def forward(self, x_):
t = self.knots
k = self.order
k = self.k
c = self.control_points

# return LabelTensor((x_**2).reshape(-1, 1), ['v'])
Expand All @@ -186,7 +183,7 @@ def forward(self, x_):
# print(self.B(x_, k, i, t), c[i])
# print(self.B(x, k, i, t), c[i])
tmp_result = torch.concatenate([
(c[i] * Spline.B(x_, k, i, t, self.order)).reshape(
(c[i] * Spline.B(x_, k, i, t)).reshape(
1, x_.shape[0], -1)
for i in range(len(c))], axis=0
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_model/test_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@ def scipy_check(model, x, y):
spline = BSpline(
t=model.knots.detach(),
c=model.control_points.detach(),
k=model.order
k=model.order-1
)
y_numpy = spline(x)
y_scipy = spline(x)
y = y.detach().numpy()
print(y)
np.testing.assert_allclose(y, y_numpy, atol=1e-5)
np.testing.assert_allclose(y, y_scipy, atol=1e-5)

def test_constructor():
Spline()
Expand All @@ -37,9 +36,10 @@ def test_constructor_wrong():
@pytest.mark.parametrize("args", valid_args)
def test_forward(args):
xi = torch.linspace(0, 1, 100)
model = Spline(**args)
model = Spline(*args)
yi = model(xi).squeeze()
scipy_check(model, xi, yi)
return


def test_backward():
Expand Down
Loading