Module 13 Discussion

Author

Robert Jenkins

Setup

library(fpp3)
library(fredr)
library(dplyr)
library(tseries)
library(fabletools)
library(ggplot2)
library(readr)
library(tidyverse)
library(tsibble)
library(feasts)
library(scales)
library(patchwork)
library(ggtime)
library(tseries)
library(readxl)
library(lubridate)
library(forecast)
library(tidyquant)
library(tidyverse)
library(quantmod)
library(lubridate)
fredr_has_key()
[1] TRUE

Data

data("vic_elec")
head(vic_elec)
# A tsibble: 6 x 5 [30m] <Australia/Melbourne>
  Time                Demand Temperature Date       Holiday
  <dttm>               <dbl>       <dbl> <date>     <lgl>  
1 2012-01-01 00:00:00  4383.        21.4 2012-01-01 TRUE   
2 2012-01-01 00:30:00  4263.        21.0 2012-01-01 TRUE   
3 2012-01-01 01:00:00  4049.        20.7 2012-01-01 TRUE   
4 2012-01-01 01:30:00  3878.        20.6 2012-01-01 TRUE   
5 2012-01-01 02:00:00  4036.        20.4 2012-01-01 TRUE   
6 2012-01-01 02:30:00  3866.        20.2 2012-01-01 TRUE   
daily_elec <- vic_elec |>
  as_tibble() |>
  mutate(Date = as.Date(Time)) |>
  group_by(Date) |>
  summarise(
    Demand = mean(Demand, na.rm = TRUE),
    Temperature = mean(Temperature, na.rm = TRUE),
    .groups = "drop"
  ) |>
  arrange(Date)

head(daily_elec)
# A tibble: 6 × 3
  Date       Demand Temperature
  <date>      <dbl>       <dbl>
1 2011-12-31  3751.        21.0
2 2012-01-01  4745.        26.6
3 2012-01-02  5739.        31.8
4 2012-01-03  5395.        24.6
5 2012-01-04  4454.        18.2
6 2012-01-05  4397.        17.8
summary(daily_elec)
      Date                Demand      Temperature    
 Min.   :2011-12-31   Min.   :3341   Min.   : 6.571  
 1st Qu.:2012-09-30   1st Qu.:4321   1st Qu.:12.537  
 Median :2013-07-01   Median :4658   Median :15.700  
 Mean   :2013-07-01   Mean   :4665   Mean   :16.269  
 3rd Qu.:2014-04-01   3rd Qu.:4981   3rd Qu.:19.161  
 Max.   :2014-12-31   Max.   :7322   Max.   :36.315  
write_csv(daily_elec, "daily_vic_elec.csv")
# Basic packages for data handling and plotting
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load the dataset I created earlier from R
df = pd.read_csv("daily_vic_elec.csv")

# Make sure Date is treated like a real date
df["Date"] = pd.to_datetime(df["Date"])

# Create lag features so the model can learn from past demand
df["lag_1"] = df["Demand"].shift(1)   # yesterday
df["lag_7"] = df["Demand"].shift(7)   # same day last week

# Create a rolling average to smooth recent demand trends
df["rolling_7"] = df["Demand"].rolling(window=7).mean()

# Drop rows with missing values caused by lagging
df = df.dropna().reset_index(drop=True)

# Quick check
df.head()
        Date       Demand  Temperature        lag_1        lag_7    rolling_7
0 2012-01-07  4181.109798    24.098958  4277.889888  3751.442996  4741.414774
1 2012-01-08  4167.950307    20.223958  4181.109798  4745.380361  4658.924766
2 2012-01-09  4504.204934    19.161458  4167.950307  5739.395602  4482.468957
3 2012-01-10  4524.518915    16.042708  4504.204934  5394.902696  4358.128416
4 2012-01-11  4531.614063    14.815625  4524.518915  4454.007853  4369.215018
# Split data into train/test (keep time order intact)
train_size = int(len(df) * 0.8)
train = df.iloc[:train_size]
test = df.iloc[train_size:]

# Features the model will use
features = ["Temperature", "lag_1", "lag_7", "rolling_7"]

X_train = train[features]
y_train = train["Demand"]

X_test = test[features]
y_test = test["Demand"]

# BASELINE MODEL (Linear Regression)
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Simple baseline so I have something to compare against
baseline_model = LinearRegression()
baseline_model.fit(X_train, y_train)
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
baseline_preds = baseline_model.predict(X_test)

# Evaluate baseline
baseline_rmse = np.sqrt(mean_squared_error(y_test, baseline_preds))
baseline_mae = mean_absolute_error(y_test, baseline_preds)

print("Baseline Linear Regression RMSE:", baseline_rmse)
Baseline Linear Regression RMSE: 259.4068820700317
print("Baseline Linear Regression MAE:", baseline_mae)
Baseline Linear Regression MAE: 206.01892235083722
# Plot baseline vs actual
plt.figure(figsize=(10,5))
plt.plot(test["Date"], y_test, label="Actual")
plt.plot(test["Date"], baseline_preds, label="Predicted")
plt.title("Baseline Model: Actual vs Predicted Electricity Demand")
plt.xlabel("Date")
plt.ylabel("Demand")
plt.legend()
plt.xticks(rotation=45)
(array([16222., 16252., 16283., 16314., 16344., 16375., 16405., 16436.]), [Text(16222.0, 0, '2014-06'), Text(16252.0, 0, '2014-07'), Text(16283.0, 0, '2014-08'), Text(16314.0, 0, '2014-09'), Text(16344.0, 0, '2014-10'), Text(16375.0, 0, '2014-11'), Text(16405.0, 0, '2014-12'), Text(16436.0, 0, '2015-01')])
plt.tight_layout()
plt.show()

# PYTORCH FORECASTING SETUP
import torch
import lightning.pytorch as pl

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

print("PyTorch imports worked")
PyTorch imports worked
# Create time index (required for PyTorch Forecasting)
df["time_idx"] = np.arange(len(df))

# Single time series ID
df["series"] = "electricity"

# Define where training ends
training_cutoff = df["time_idx"].max() - len(test)

# Drop Date for modeling (model only wants numeric/categorical inputs)
model_df = df.drop(columns=["Date"]).copy()
model_df = model_df.dropna().reset_index(drop=True)

# How much history the model sees and how far it predicts
max_encoder_length = 30
max_prediction_length = 7

# Build dataset object for PyTorch
training = TimeSeriesDataSet(
    model_df[model_df.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="Demand",
    group_ids=["series"],

    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,

    min_prediction_length=1,
    max_prediction_length=max_prediction_length,

    static_categoricals=["series"],
    time_varying_known_reals=["time_idx"],

    # All variables the model learns from over time
    time_varying_unknown_reals=[
        "Demand",
        "Temperature",
        "lag_1",
        "lag_7",
        "rolling_7"
    ],

    target_normalizer=GroupNormalizer(groups=["series"]),
)

# Validation dataset
validation = TimeSeriesDataSet.from_dataset(
    training,
    model_df,
    predict=True,
    stop_randomization=True,
)

batch_size = 32

train_dataloader = training.to_dataloader(
    train=True,
    batch_size=batch_size,
    num_workers=0
)

val_dataloader = validation.to_dataloader(
    train=False,
    batch_size=batch_size,
    num_workers=0
)

# BUILD AND TRAIN TFT MODEL
tft_model = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    loss=QuantileLoss(),
)

print("TFT model built")
TFT model built
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="cpu",
    enable_checkpointing=False,
    logger=False
)

trainer.fit(
    tft_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Sanity Checking: |          | 0/? [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|##########| 1/1 [00:00<00:00, 21.57it/s]
                                                                           

Training: |          | 0/? [00:00<?, ?it/s]
Training: |          | 0/? [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/26 [00:00<?, ?it/s]
Epoch 0:   4%|3         | 1/26 [00:00<00:03,  8.01it/s]
Epoch 0:   4%|3         | 1/26 [00:00<00:03,  7.94it/s, train_loss_step=423.0]
Epoch 0:   8%|7         | 2/26 [00:00<00:02,  8.33it/s, train_loss_step=423.0]
Epoch 0:   8%|7         | 2/26 [00:00<00:02,  8.32it/s, train_loss_step=227.0]
Epoch 0:  12%|#1        | 3/26 [00:00<00:02,  8.84it/s, train_loss_step=227.0]
Epoch 0:  12%|#1        | 3/26 [00:00<00:02,  8.84it/s, train_loss_step=254.0]
Epoch 0:  15%|#5        | 4/26 [00:00<00:02,  8.19it/s, train_loss_step=254.0]
Epoch 0:  15%|#5        | 4/26 [00:00<00:02,  8.19it/s, train_loss_step=226.0]
Epoch 0:  19%|#9        | 5/26 [00:00<00:02,  7.69it/s, train_loss_step=226.0]
Epoch 0:  19%|#9        | 5/26 [00:00<00:02,  7.68it/s, train_loss_step=209.0]
Epoch 0:  23%|##3       | 6/26 [00:00<00:02,  7.37it/s, train_loss_step=209.0]
Epoch 0:  23%|##3       | 6/26 [00:00<00:02,  7.36it/s, train_loss_step=232.0]
Epoch 0:  27%|##6       | 7/26 [00:00<00:02,  7.19it/s, train_loss_step=232.0]
Epoch 0:  27%|##6       | 7/26 [00:00<00:02,  7.19it/s, train_loss_step=237.0]
Epoch 0:  31%|###       | 8/26 [00:01<00:02,  7.07it/s, train_loss_step=237.0]
Epoch 0:  31%|###       | 8/26 [00:01<00:02,  7.07it/s, train_loss_step=224.0]
Epoch 0:  35%|###4      | 9/26 [00:01<00:02,  6.96it/s, train_loss_step=224.0]
Epoch 0:  35%|###4      | 9/26 [00:01<00:02,  6.96it/s, train_loss_step=250.0]
Epoch 0:  38%|###8      | 10/26 [00:01<00:02,  6.89it/s, train_loss_step=250.0]
Epoch 0:  38%|###8      | 10/26 [00:01<00:02,  6.89it/s, train_loss_step=227.0]
Epoch 0:  42%|####2     | 11/26 [00:01<00:02,  6.82it/s, train_loss_step=227.0]
Epoch 0:  42%|####2     | 11/26 [00:01<00:02,  6.82it/s, train_loss_step=234.0]
Epoch 0:  46%|####6     | 12/26 [00:01<00:02,  6.85it/s, train_loss_step=234.0]
Epoch 0:  46%|####6     | 12/26 [00:01<00:02,  6.85it/s, train_loss_step=214.0]
Epoch 0:  50%|#####     | 13/26 [00:01<00:01,  7.05it/s, train_loss_step=214.0]
Epoch 0:  50%|#####     | 13/26 [00:01<00:01,  7.05it/s, train_loss_step=215.0]
Epoch 0:  54%|#####3    | 14/26 [00:01<00:01,  7.06it/s, train_loss_step=215.0]
Epoch 0:  54%|#####3    | 14/26 [00:01<00:01,  7.06it/s, train_loss_step=275.0]
Epoch 0:  58%|#####7    | 15/26 [00:02<00:01,  6.98it/s, train_loss_step=275.0]
Epoch 0:  58%|#####7    | 15/26 [00:02<00:01,  6.98it/s, train_loss_step=191.0]
Epoch 0:  62%|######1   | 16/26 [00:02<00:01,  6.92it/s, train_loss_step=191.0]
Epoch 0:  62%|######1   | 16/26 [00:02<00:01,  6.92it/s, train_loss_step=201.0]
Epoch 0:  65%|######5   | 17/26 [00:02<00:01,  6.86it/s, train_loss_step=201.0]
Epoch 0:  65%|######5   | 17/26 [00:02<00:01,  6.86it/s, train_loss_step=222.0]
Epoch 0:  69%|######9   | 18/26 [00:02<00:01,  6.78it/s, train_loss_step=222.0]
Epoch 0:  69%|######9   | 18/26 [00:02<00:01,  6.78it/s, train_loss_step=204.0]
Epoch 0:  73%|#######3  | 19/26 [00:02<00:01,  6.73it/s, train_loss_step=204.0]
Epoch 0:  73%|#######3  | 19/26 [00:02<00:01,  6.73it/s, train_loss_step=230.0]
Epoch 0:  77%|#######6  | 20/26 [00:03<00:00,  6.66it/s, train_loss_step=230.0]
Epoch 0:  77%|#######6  | 20/26 [00:03<00:00,  6.66it/s, train_loss_step=202.0]
Epoch 0:  81%|########  | 21/26 [00:03<00:00,  6.72it/s, train_loss_step=202.0]
Epoch 0:  81%|########  | 21/26 [00:03<00:00,  6.72it/s, train_loss_step=215.0]
Epoch 0:  85%|########4 | 22/26 [00:03<00:00,  6.84it/s, train_loss_step=215.0]
Epoch 0:  85%|########4 | 22/26 [00:03<00:00,  6.84it/s, train_loss_step=232.0]
Epoch 0:  88%|########8 | 23/26 [00:03<00:00,  6.95it/s, train_loss_step=232.0]
Epoch 0:  88%|########8 | 23/26 [00:03<00:00,  6.95it/s, train_loss_step=204.0]
Epoch 0:  92%|#########2| 24/26 [00:03<00:00,  7.07it/s, train_loss_step=204.0]
Epoch 0:  92%|#########2| 24/26 [00:03<00:00,  7.07it/s, train_loss_step=181.0]
Epoch 0:  96%|#########6| 25/26 [00:03<00:00,  7.17it/s, train_loss_step=181.0]
Epoch 0:  96%|#########6| 25/26 [00:03<00:00,  7.17it/s, train_loss_step=216.0]
Epoch 0: 100%|##########| 26/26 [00:03<00:00,  7.18it/s, train_loss_step=216.0]
Epoch 0: 100%|##########| 26/26 [00:03<00:00,  7.18it/s, train_loss_step=207.0]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

Validation DataLoader 0: 100%|##########| 1/1 [00:00<00:00, 25.59it/s]

                                                                      
Epoch 0: 100%|##########| 26/26 [00:03<00:00,  7.08it/s, train_loss_step=207.0, val_loss=346.0]
Epoch 0: 100%|##########| 26/26 [00:03<00:00,  7.08it/s, train_loss_step=207.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 0:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=207.0, val_loss=346.0, train_loss_epoch=229.0]         
Epoch 1:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=207.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:   4%|3         | 1/26 [00:00<00:04,  6.13it/s, train_loss_step=207.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:   4%|3         | 1/26 [00:00<00:04,  6.13it/s, train_loss_step=205.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:   8%|7         | 2/26 [00:00<00:03,  6.14it/s, train_loss_step=205.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:   8%|7         | 2/26 [00:00<00:03,  6.13it/s, train_loss_step=226.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  12%|#1        | 3/26 [00:00<00:03,  6.12it/s, train_loss_step=226.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  12%|#1        | 3/26 [00:00<00:03,  6.12it/s, train_loss_step=181.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  15%|#5        | 4/26 [00:00<00:03,  6.14it/s, train_loss_step=181.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  15%|#5        | 4/26 [00:00<00:03,  6.14it/s, train_loss_step=229.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  19%|#9        | 5/26 [00:00<00:03,  6.22it/s, train_loss_step=229.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  19%|#9        | 5/26 [00:00<00:03,  6.22it/s, train_loss_step=193.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  23%|##3       | 6/26 [00:00<00:03,  6.32it/s, train_loss_step=193.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  23%|##3       | 6/26 [00:00<00:03,  6.32it/s, train_loss_step=197.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  27%|##6       | 7/26 [00:01<00:02,  6.65it/s, train_loss_step=197.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  27%|##6       | 7/26 [00:01<00:02,  6.65it/s, train_loss_step=219.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  31%|###       | 8/26 [00:01<00:02,  7.00it/s, train_loss_step=219.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  31%|###       | 8/26 [00:01<00:02,  7.00it/s, train_loss_step=203.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  35%|###4      | 9/26 [00:01<00:02,  7.22it/s, train_loss_step=203.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  35%|###4      | 9/26 [00:01<00:02,  7.22it/s, train_loss_step=203.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  38%|###8      | 10/26 [00:01<00:02,  7.29it/s, train_loss_step=203.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  38%|###8      | 10/26 [00:01<00:02,  7.29it/s, train_loss_step=190.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  42%|####2     | 11/26 [00:01<00:02,  7.26it/s, train_loss_step=190.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  42%|####2     | 11/26 [00:01<00:02,  7.26it/s, train_loss_step=208.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  46%|####6     | 12/26 [00:01<00:01,  7.32it/s, train_loss_step=208.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  46%|####6     | 12/26 [00:01<00:01,  7.32it/s, train_loss_step=187.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  50%|#####     | 13/26 [00:01<00:01,  7.42it/s, train_loss_step=187.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  50%|#####     | 13/26 [00:01<00:01,  7.42it/s, train_loss_step=206.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  54%|#####3    | 14/26 [00:01<00:01,  7.36it/s, train_loss_step=206.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  54%|#####3    | 14/26 [00:01<00:01,  7.36it/s, train_loss_step=214.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  58%|#####7    | 15/26 [00:01<00:01,  7.54it/s, train_loss_step=214.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  58%|#####7    | 15/26 [00:01<00:01,  7.54it/s, train_loss_step=219.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  62%|######1   | 16/26 [00:02<00:01,  7.61it/s, train_loss_step=219.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  62%|######1   | 16/26 [00:02<00:01,  7.61it/s, train_loss_step=196.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  65%|######5   | 17/26 [00:02<00:01,  7.65it/s, train_loss_step=196.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  65%|######5   | 17/26 [00:02<00:01,  7.64it/s, train_loss_step=227.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  69%|######9   | 18/26 [00:02<00:01,  7.58it/s, train_loss_step=227.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  69%|######9   | 18/26 [00:02<00:01,  7.58it/s, train_loss_step=192.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  73%|#######3  | 19/26 [00:02<00:00,  7.52it/s, train_loss_step=192.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  73%|#######3  | 19/26 [00:02<00:00,  7.52it/s, train_loss_step=168.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  77%|#######6  | 20/26 [00:02<00:00,  7.48it/s, train_loss_step=168.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  77%|#######6  | 20/26 [00:02<00:00,  7.48it/s, train_loss_step=208.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  81%|########  | 21/26 [00:02<00:00,  7.43it/s, train_loss_step=208.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  81%|########  | 21/26 [00:02<00:00,  7.43it/s, train_loss_step=183.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  85%|########4 | 22/26 [00:02<00:00,  7.37it/s, train_loss_step=183.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  85%|########4 | 22/26 [00:02<00:00,  7.37it/s, train_loss_step=180.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  88%|########8 | 23/26 [00:03<00:00,  7.32it/s, train_loss_step=180.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  88%|########8 | 23/26 [00:03<00:00,  7.32it/s, train_loss_step=171.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  92%|#########2| 24/26 [00:03<00:00,  7.28it/s, train_loss_step=171.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  92%|#########2| 24/26 [00:03<00:00,  7.28it/s, train_loss_step=224.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  96%|#########6| 25/26 [00:03<00:00,  7.23it/s, train_loss_step=224.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1:  96%|#########6| 25/26 [00:03<00:00,  7.23it/s, train_loss_step=190.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1: 100%|##########| 26/26 [00:03<00:00,  7.19it/s, train_loss_step=190.0, val_loss=346.0, train_loss_epoch=229.0]
Epoch 1: 100%|##########| 26/26 [00:03<00:00,  7.19it/s, train_loss_step=208.0, val_loss=346.0, train_loss_epoch=229.0]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

Validation DataLoader 0: 100%|##########| 1/1 [00:00<00:00, 32.12it/s]

                                                                      
Epoch 1: 100%|##########| 26/26 [00:03<00:00,  7.11it/s, train_loss_step=208.0, val_loss=289.0, train_loss_epoch=229.0]
Epoch 1: 100%|##########| 26/26 [00:03<00:00,  7.10it/s, train_loss_step=208.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 1:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=208.0, val_loss=289.0, train_loss_epoch=201.0]         
Epoch 2:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=208.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:   4%|3         | 1/26 [00:00<00:04,  5.62it/s, train_loss_step=208.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:   4%|3         | 1/26 [00:00<00:04,  5.62it/s, train_loss_step=183.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:   8%|7         | 2/26 [00:00<00:03,  6.29it/s, train_loss_step=183.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:   8%|7         | 2/26 [00:00<00:03,  6.27it/s, train_loss_step=179.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  12%|#1        | 3/26 [00:00<00:03,  6.05it/s, train_loss_step=179.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  12%|#1        | 3/26 [00:00<00:03,  6.05it/s, train_loss_step=192.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  15%|#5        | 4/26 [00:00<00:03,  5.95it/s, train_loss_step=192.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  15%|#5        | 4/26 [00:00<00:03,  5.95it/s, train_loss_step=228.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  19%|#9        | 5/26 [00:00<00:03,  5.90it/s, train_loss_step=228.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  19%|#9        | 5/26 [00:00<00:03,  5.90it/s, train_loss_step=192.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  23%|##3       | 6/26 [00:01<00:03,  5.80it/s, train_loss_step=192.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  23%|##3       | 6/26 [00:01<00:03,  5.80it/s, train_loss_step=186.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  27%|##6       | 7/26 [00:01<00:03,  5.74it/s, train_loss_step=186.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  27%|##6       | 7/26 [00:01<00:03,  5.74it/s, train_loss_step=198.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  31%|###       | 8/26 [00:01<00:03,  5.70it/s, train_loss_step=198.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  31%|###       | 8/26 [00:01<00:03,  5.70it/s, train_loss_step=233.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  35%|###4      | 9/26 [00:01<00:03,  5.67it/s, train_loss_step=233.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  35%|###4      | 9/26 [00:01<00:03,  5.67it/s, train_loss_step=187.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  38%|###8      | 10/26 [00:01<00:02,  5.66it/s, train_loss_step=187.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  38%|###8      | 10/26 [00:01<00:02,  5.66it/s, train_loss_step=196.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  42%|####2     | 11/26 [00:01<00:02,  5.66it/s, train_loss_step=196.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  42%|####2     | 11/26 [00:01<00:02,  5.66it/s, train_loss_step=186.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  46%|####6     | 12/26 [00:02<00:02,  5.63it/s, train_loss_step=186.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  46%|####6     | 12/26 [00:02<00:02,  5.63it/s, train_loss_step=186.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  50%|#####     | 13/26 [00:02<00:02,  5.63it/s, train_loss_step=186.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  50%|#####     | 13/26 [00:02<00:02,  5.63it/s, train_loss_step=260.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  54%|#####3    | 14/26 [00:02<00:02,  5.62it/s, train_loss_step=260.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  54%|#####3    | 14/26 [00:02<00:02,  5.62it/s, train_loss_step=209.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  58%|#####7    | 15/26 [00:02<00:01,  5.62it/s, train_loss_step=209.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  58%|#####7    | 15/26 [00:02<00:01,  5.62it/s, train_loss_step=189.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  62%|######1   | 16/26 [00:02<00:01,  5.62it/s, train_loss_step=189.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  62%|######1   | 16/26 [00:02<00:01,  5.62it/s, train_loss_step=215.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  65%|######5   | 17/26 [00:03<00:01,  5.63it/s, train_loss_step=215.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  65%|######5   | 17/26 [00:03<00:01,  5.63it/s, train_loss_step=165.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  69%|######9   | 18/26 [00:03<00:01,  5.64it/s, train_loss_step=165.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  69%|######9   | 18/26 [00:03<00:01,  5.64it/s, train_loss_step=196.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  73%|#######3  | 19/26 [00:03<00:01,  5.65it/s, train_loss_step=196.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  73%|#######3  | 19/26 [00:03<00:01,  5.65it/s, train_loss_step=165.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  77%|#######6  | 20/26 [00:03<00:01,  5.66it/s, train_loss_step=165.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  77%|#######6  | 20/26 [00:03<00:01,  5.66it/s, train_loss_step=191.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  81%|########  | 21/26 [00:03<00:00,  5.69it/s, train_loss_step=191.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  81%|########  | 21/26 [00:03<00:00,  5.69it/s, train_loss_step=234.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  85%|########4 | 22/26 [00:03<00:00,  5.72it/s, train_loss_step=234.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  85%|########4 | 22/26 [00:03<00:00,  5.72it/s, train_loss_step=202.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  88%|########8 | 23/26 [00:04<00:00,  5.74it/s, train_loss_step=202.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  88%|########8 | 23/26 [00:04<00:00,  5.74it/s, train_loss_step=189.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  92%|#########2| 24/26 [00:04<00:00,  5.76it/s, train_loss_step=189.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  92%|#########2| 24/26 [00:04<00:00,  5.76it/s, train_loss_step=204.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  96%|#########6| 25/26 [00:04<00:00,  5.80it/s, train_loss_step=204.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2:  96%|#########6| 25/26 [00:04<00:00,  5.80it/s, train_loss_step=181.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2: 100%|##########| 26/26 [00:04<00:00,  5.81it/s, train_loss_step=181.0, val_loss=289.0, train_loss_epoch=201.0]
Epoch 2: 100%|##########| 26/26 [00:04<00:00,  5.81it/s, train_loss_step=200.0, val_loss=289.0, train_loss_epoch=201.0]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

Validation DataLoader 0: 100%|##########| 1/1 [00:00<00:00, 29.24it/s]

                                                                      
Epoch 2: 100%|##########| 26/26 [00:04<00:00,  5.76it/s, train_loss_step=200.0, val_loss=333.0, train_loss_epoch=201.0]
Epoch 2: 100%|##########| 26/26 [00:04<00:00,  5.75it/s, train_loss_step=200.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 2:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=200.0, val_loss=333.0, train_loss_epoch=198.0]         
Epoch 3:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=200.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:   4%|3         | 1/26 [00:00<00:04,  6.22it/s, train_loss_step=200.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:   4%|3         | 1/26 [00:00<00:04,  6.22it/s, train_loss_step=222.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:   8%|7         | 2/26 [00:00<00:03,  7.10it/s, train_loss_step=222.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:   8%|7         | 2/26 [00:00<00:03,  7.10it/s, train_loss_step=196.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  12%|#1        | 3/26 [00:00<00:02,  7.69it/s, train_loss_step=196.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  12%|#1        | 3/26 [00:00<00:02,  7.69it/s, train_loss_step=193.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  15%|#5        | 4/26 [00:00<00:02,  8.03it/s, train_loss_step=193.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  15%|#5        | 4/26 [00:00<00:02,  8.03it/s, train_loss_step=199.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  19%|#9        | 5/26 [00:00<00:02,  7.65it/s, train_loss_step=199.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  19%|#9        | 5/26 [00:00<00:02,  7.65it/s, train_loss_step=225.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  23%|##3       | 6/26 [00:00<00:02,  7.50it/s, train_loss_step=225.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  23%|##3       | 6/26 [00:00<00:02,  7.50it/s, train_loss_step=179.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  27%|##6       | 7/26 [00:00<00:02,  7.42it/s, train_loss_step=179.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  27%|##6       | 7/26 [00:00<00:02,  7.41it/s, train_loss_step=241.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  31%|###       | 8/26 [00:01<00:02,  7.30it/s, train_loss_step=241.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  31%|###       | 8/26 [00:01<00:02,  7.29it/s, train_loss_step=189.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  35%|###4      | 9/26 [00:01<00:02,  7.25it/s, train_loss_step=189.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  35%|###4      | 9/26 [00:01<00:02,  7.25it/s, train_loss_step=185.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  38%|###8      | 10/26 [00:01<00:02,  7.32it/s, train_loss_step=185.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  38%|###8      | 10/26 [00:01<00:02,  7.32it/s, train_loss_step=216.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  42%|####2     | 11/26 [00:01<00:02,  7.26it/s, train_loss_step=216.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  42%|####2     | 11/26 [00:01<00:02,  7.26it/s, train_loss_step=201.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  46%|####6     | 12/26 [00:01<00:01,  7.19it/s, train_loss_step=201.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  46%|####6     | 12/26 [00:01<00:01,  7.18it/s, train_loss_step=186.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  50%|#####     | 13/26 [00:01<00:01,  7.09it/s, train_loss_step=186.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  50%|#####     | 13/26 [00:01<00:01,  7.09it/s, train_loss_step=198.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  54%|#####3    | 14/26 [00:02<00:01,  7.00it/s, train_loss_step=198.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  54%|#####3    | 14/26 [00:02<00:01,  7.00it/s, train_loss_step=204.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  58%|#####7    | 15/26 [00:02<00:01,  6.99it/s, train_loss_step=204.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  58%|#####7    | 15/26 [00:02<00:01,  6.99it/s, train_loss_step=194.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  62%|######1   | 16/26 [00:02<00:01,  6.98it/s, train_loss_step=194.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  62%|######1   | 16/26 [00:02<00:01,  6.98it/s, train_loss_step=197.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  65%|######5   | 17/26 [00:02<00:01,  7.00it/s, train_loss_step=197.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  65%|######5   | 17/26 [00:02<00:01,  7.00it/s, train_loss_step=211.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  69%|######9   | 18/26 [00:02<00:01,  7.02it/s, train_loss_step=211.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  69%|######9   | 18/26 [00:02<00:01,  7.02it/s, train_loss_step=194.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  73%|#######3  | 19/26 [00:02<00:00,  7.04it/s, train_loss_step=194.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  73%|#######3  | 19/26 [00:02<00:00,  7.04it/s, train_loss_step=183.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  77%|#######6  | 20/26 [00:02<00:00,  7.11it/s, train_loss_step=183.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  77%|#######6  | 20/26 [00:02<00:00,  7.11it/s, train_loss_step=174.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  81%|########  | 21/26 [00:02<00:00,  7.24it/s, train_loss_step=174.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  81%|########  | 21/26 [00:02<00:00,  7.24it/s, train_loss_step=183.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  85%|########4 | 22/26 [00:03<00:00,  7.33it/s, train_loss_step=183.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  85%|########4 | 22/26 [00:03<00:00,  7.32it/s, train_loss_step=188.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  88%|########8 | 23/26 [00:03<00:00,  7.39it/s, train_loss_step=188.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  88%|########8 | 23/26 [00:03<00:00,  7.39it/s, train_loss_step=173.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  92%|#########2| 24/26 [00:03<00:00,  7.47it/s, train_loss_step=173.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  92%|#########2| 24/26 [00:03<00:00,  7.47it/s, train_loss_step=236.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  96%|#########6| 25/26 [00:03<00:00,  7.52it/s, train_loss_step=236.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3:  96%|#########6| 25/26 [00:03<00:00,  7.52it/s, train_loss_step=200.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3: 100%|##########| 26/26 [00:03<00:00,  7.61it/s, train_loss_step=200.0, val_loss=333.0, train_loss_epoch=198.0]
Epoch 3: 100%|##########| 26/26 [00:03<00:00,  7.61it/s, train_loss_step=177.0, val_loss=333.0, train_loss_epoch=198.0]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

Validation DataLoader 0: 100%|##########| 1/1 [00:00<00:00, 40.54it/s]

                                                                      
Epoch 3: 100%|##########| 26/26 [00:03<00:00,  7.54it/s, train_loss_step=177.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 3: 100%|##########| 26/26 [00:03<00:00,  7.53it/s, train_loss_step=177.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 3:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=177.0, val_loss=323.0, train_loss_epoch=198.0]         
Epoch 4:   0%|          | 0/26 [00:00<?, ?it/s, train_loss_step=177.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:   4%|3         | 1/26 [00:00<00:02, 11.07it/s, train_loss_step=177.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:   4%|3         | 1/26 [00:00<00:02, 11.07it/s, train_loss_step=239.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:   8%|7         | 2/26 [00:00<00:02,  8.90it/s, train_loss_step=239.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:   8%|7         | 2/26 [00:00<00:02,  8.85it/s, train_loss_step=193.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  12%|#1        | 3/26 [00:00<00:02,  8.24it/s, train_loss_step=193.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  12%|#1        | 3/26 [00:00<00:02,  8.24it/s, train_loss_step=181.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  15%|#5        | 4/26 [00:00<00:02,  7.89it/s, train_loss_step=181.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  15%|#5        | 4/26 [00:00<00:02,  7.89it/s, train_loss_step=188.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  19%|#9        | 5/26 [00:00<00:02,  7.69it/s, train_loss_step=188.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  19%|#9        | 5/26 [00:00<00:02,  7.68it/s, train_loss_step=179.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  23%|##3       | 6/26 [00:00<00:02,  7.60it/s, train_loss_step=179.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  23%|##3       | 6/26 [00:00<00:02,  7.59it/s, train_loss_step=220.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  27%|##6       | 7/26 [00:00<00:02,  7.50it/s, train_loss_step=220.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  27%|##6       | 7/26 [00:00<00:02,  7.49it/s, train_loss_step=187.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  31%|###       | 8/26 [00:01<00:02,  7.42it/s, train_loss_step=187.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  31%|###       | 8/26 [00:01<00:02,  7.42it/s, train_loss_step=222.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  35%|###4      | 9/26 [00:01<00:02,  7.35it/s, train_loss_step=222.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  35%|###4      | 9/26 [00:01<00:02,  7.35it/s, train_loss_step=215.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  38%|###8      | 10/26 [00:01<00:02,  7.32it/s, train_loss_step=215.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  38%|###8      | 10/26 [00:01<00:02,  7.32it/s, train_loss_step=178.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  42%|####2     | 11/26 [00:01<00:02,  7.31it/s, train_loss_step=178.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  42%|####2     | 11/26 [00:01<00:02,  7.31it/s, train_loss_step=178.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  46%|####6     | 12/26 [00:01<00:01,  7.29it/s, train_loss_step=178.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  46%|####6     | 12/26 [00:01<00:01,  7.29it/s, train_loss_step=174.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  50%|#####     | 13/26 [00:01<00:01,  7.31it/s, train_loss_step=174.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  50%|#####     | 13/26 [00:01<00:01,  7.31it/s, train_loss_step=182.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  54%|#####3    | 14/26 [00:01<00:01,  7.32it/s, train_loss_step=182.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  54%|#####3    | 14/26 [00:01<00:01,  7.32it/s, train_loss_step=207.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  58%|#####7    | 15/26 [00:02<00:01,  7.34it/s, train_loss_step=207.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  58%|#####7    | 15/26 [00:02<00:01,  7.34it/s, train_loss_step=177.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  62%|######1   | 16/26 [00:02<00:01,  7.37it/s, train_loss_step=177.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  62%|######1   | 16/26 [00:02<00:01,  7.36it/s, train_loss_step=203.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  65%|######5   | 17/26 [00:02<00:01,  7.40it/s, train_loss_step=203.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  65%|######5   | 17/26 [00:02<00:01,  7.40it/s, train_loss_step=218.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  69%|######9   | 18/26 [00:02<00:01,  7.50it/s, train_loss_step=218.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  69%|######9   | 18/26 [00:02<00:01,  7.50it/s, train_loss_step=232.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  73%|#######3  | 19/26 [00:02<00:00,  7.54it/s, train_loss_step=232.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  73%|#######3  | 19/26 [00:02<00:00,  7.54it/s, train_loss_step=198.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  77%|#######6  | 20/26 [00:02<00:00,  7.58it/s, train_loss_step=198.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  77%|#######6  | 20/26 [00:02<00:00,  7.57it/s, train_loss_step=193.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  81%|########  | 21/26 [00:02<00:00,  7.61it/s, train_loss_step=193.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  81%|########  | 21/26 [00:02<00:00,  7.60it/s, train_loss_step=180.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  85%|########4 | 22/26 [00:02<00:00,  7.62it/s, train_loss_step=180.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  85%|########4 | 22/26 [00:02<00:00,  7.62it/s, train_loss_step=187.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  88%|########8 | 23/26 [00:02<00:00,  7.68it/s, train_loss_step=187.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  88%|########8 | 23/26 [00:02<00:00,  7.67it/s, train_loss_step=187.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  92%|#########2| 24/26 [00:03<00:00,  7.73it/s, train_loss_step=187.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  92%|#########2| 24/26 [00:03<00:00,  7.73it/s, train_loss_step=196.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  96%|#########6| 25/26 [00:03<00:00,  7.75it/s, train_loss_step=196.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4:  96%|#########6| 25/26 [00:03<00:00,  7.75it/s, train_loss_step=226.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4: 100%|##########| 26/26 [00:03<00:00,  7.82it/s, train_loss_step=226.0, val_loss=323.0, train_loss_epoch=198.0]
Epoch 4: 100%|##########| 26/26 [00:03<00:00,  7.82it/s, train_loss_step=198.0, val_loss=323.0, train_loss_epoch=198.0]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

Validation DataLoader 0: 100%|##########| 1/1 [00:00<00:00, 45.35it/s]

                                                                      
Epoch 4: 100%|##########| 26/26 [00:03<00:00,  7.75it/s, train_loss_step=198.0, val_loss=310.0, train_loss_epoch=198.0]
Epoch 4: 100%|##########| 26/26 [00:03<00:00,  7.74it/s, train_loss_step=198.0, val_loss=310.0, train_loss_epoch=198.0]
Epoch 4: 100%|##########| 26/26 [00:03<00:00,  7.74it/s, train_loss_step=198.0, val_loss=310.0, train_loss_epoch=198.0]
print("TFT training complete")
TFT training complete
# PREDICTIONS
tft_preds = tft_model.predict(val_dataloader)
tft_preds = tft_preds.numpy().flatten()

tft_actual = y_test.iloc[:len(tft_preds)]
tft_dates = test["Date"].iloc[:len(tft_preds)]

# Plot TFT results
plt.figure(figsize=(10,5))
plt.plot(tft_dates, tft_actual, label="Actual")
plt.plot(tft_dates, tft_preds, label="TFT Predicted")
plt.title("TFT Model: Actual vs Predicted Electricity Demand")
plt.xlabel("Date")
plt.ylabel("Demand")
plt.legend()
plt.xticks(rotation=45)
(array([16218., 16219., 16220., 16221., 16222., 16223., 16224.]), [Text(16218.0, 0, '2014-05-28'), Text(16219.0, 0, '2014-05-29'), Text(16220.0, 0, '2014-05-30'), Text(16221.0, 0, '2014-05-31'), Text(16222.0, 0, '2014-06-01'), Text(16223.0, 0, '2014-06-02'), Text(16224.0, 0, '2014-06-03')])
plt.tight_layout()
plt.show()

# EVALUATE TFT
tft_rmse = np.sqrt(mean_squared_error(tft_actual, tft_preds))
tft_mae = mean_absolute_error(tft_actual, tft_preds)

print("TFT RMSE:", tft_rmse)
TFT RMSE: 348.13858198472576
print("TFT MAE:", tft_mae)
TFT MAE: 308.6852025431546