Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions autoregressive_codice_prova.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import torch
import matplotlib.pyplot as plt

from pina import Trainer
from pina.optim import TorchOptimizer
from pina.problem import AbstractProblem
from pina.condition.data_condition import DataCondition
from pina.solver import AutoregressiveSolver

NUM_TIMESTEPS = 100
NUM_FEATURES = 15
USE_TEST_MODEL = False

# ============================================================================
# DATA
# ============================================================================

torch.manual_seed(42)

y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
y[0] = torch.rand(NUM_FEATURES) # Random initial state

for t in range(NUM_TIMESTEPS - 1):
y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum())

# ============================================================================
# TRAINING
# ============================================================================

class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(y.shape[1], 15),
torch.nn.Tanh(),
# torch.nn.Dropout(0.1),
torch.nn.Linear(15, y.shape[1]),
)

def forward(self, x):
return x + self.layers(x)


class TestModel(torch.nn.Module):
"""
Debug model that implements the EXACT transformation rule.
y[t+1] = 0.95 * y[t]
Expected loss is zero
"""

def __init__(self, data_series=None):
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.zeros(1))

def forward(self, x):
next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
return next_state + 0.0 * self.dummy_param

# create a problem with duplicated data conditions
class Problem(AbstractProblem):
output_variables = None
input_variables = None

# create two different unroll datasets: short and medium
y_short = AutoregressiveSolver.unroll(
y, unroll_length=4, num_unrolls=20, randomize=False
)
y_medium = AutoregressiveSolver.unroll(
y, unroll_length=10, num_unrolls=15, randomize=False
)
y_long = AutoregressiveSolver.unroll(
y, unroll_length=20, num_unrolls=10, randomize=False
)

conditions = {}

inactive_conditions = {
"short": DataCondition(input=y_short),
"medium": DataCondition(input=y_medium),
"long": DataCondition(input=y_long),
}

# Settings kept separate from the DataCondition objects
conditions_settings = {
"short": {"eps": 0.1},
"medium": {"eps": 1.0},
"long": {"eps": 2.0},
}


problem = Problem()

# helper that allows to activate or replace a condition at runtime
def activate_condition(problem, name, data=None, settings=None):
"""
Activate a single condition by name.

`conditions_settings` is left untouched unless `settings` is explicitly
provided and no entry exists yet for `name`.
"""
# if data is provided, (re)register condition in inactive store
if data is not None:
problem.inactive_conditions[name] = DataCondition(input=data)

problem.conditions = {}
problem.conditions[name] = problem.inactive_conditions[name]

if settings is not None:
problem.conditions_settings[name] = settings

# configure solver and trainer
solver = AutoregressiveSolver(
problem=problem,
model=TestModel() if USE_TEST_MODEL else SimpleModel(),
optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.011),
)


print("Beginning phase 1: training with 'short' condition only")
activate_condition(problem, "short")
trainer1 = Trainer(solver, max_epochs=300, accelerator="cpu", enable_model_summary=False)
trainer1.train()

print("Beginning phase 2: training with 'medium' condition added")
activate_condition(problem, "medium")
trainer2 = Trainer(solver, max_epochs=500, accelerator="cpu", enable_model_summary=False)
trainer2.train()

print("Beginning phase 3: training with 'long' condition added")
activate_condition(problem, "long")
trainer3 = Trainer(solver, max_epochs=900, accelerator="cpu", enable_model_summary=False)
trainer3.train()


# ============================================================================
test_start_idx = 50
num_prediction_steps = 49
initial_state = y[test_start_idx] # Shape: [features]
predictions = solver.predict(initial_state, num_prediction_steps)
actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1]

print("\n=== PREDICTION DEBUG ===")
for i in range(min(10, num_prediction_steps)):
pred_val = predictions[i].mean().item()
actual_val = actual[i].mean().item()
error = (predictions[i] - actual[i]).abs().mean().item()
print(f"Step {i}: pred={pred_val:.4f}, actual={actual_val:.4f}, error={error:.4f}")

total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:])
print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}")

# visualize single dof
dof_to_plot = [0, 3, 6, 9, 12]
colors = [
"r",
"g",
"b",
"c",
"m",
"y",
"k",
]
plt.figure(figsize=(10, 6))
for dof, color in zip(dof_to_plot, colors):
plt.plot(
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
actual[:, dof].numpy(),
label="Actual",
marker="o",
color=color,
markerfacecolor="none",
)
plt.plot(
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
predictions[:, dof].numpy(),
label="Predicted",
marker="x",
color=color,
)

plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}")
plt.legend()
plt.xlabel("Timestep")
plt.savefig(f"autoregressive_predictions.png")
plt.close()
5 changes: 5 additions & 0 deletions pina/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"DeepEnsembleSupervisedSolver",
"DeepEnsemblePINN",
"GAROM",
"AutoregressiveSolver",
]

from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
Expand All @@ -41,3 +42,7 @@
DeepEnsemblePINN,
)
from .garom import GAROM
from .autoregressive_solver import (
AutoregressiveSolver,
AutoregressiveSolverInterface,
)
4 changes: 4 additions & 0 deletions pina/solver/autoregressive_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]

from .autoregressive_solver import AutoregressiveSolver
from .autoregressive_solver_interface import AutoregressiveSolverInterface
Loading