GINE-0.5 / model.py
ISeeTheFuture's picture
first commit
42eff2f
raw
history blame contribute delete
583 Bytes
import torch.nn as nn
class GPSCorrectionLSTM(nn.Module):
    """Stacked LSTM followed by dropout and a linear head.

    The network consumes a batch-first sequence, keeps only the LSTM output
    at the final time step, applies dropout to it, and projects it to
    ``output_size`` values.

    NOTE(review): the default ``output_size=2`` presumably corresponds to a
    2-D position correction (going by the class name) -- confirm against the
    training code.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced per input sequence.
        dropout: Dropout probability, used both between stacked LSTM layers
            (when ``num_layers > 1``) and on the last-step features before
            the linear head.
    """

    def __init__(self, input_size: int, hidden_size: int = 128,
                 num_layers: int = 2, output_size: int = 2,
                 dropout: float = 0.3) -> None:
        super().__init__()
        # nn.LSTM applies dropout only *between* stacked layers. With a
        # single layer it has no effect and PyTorch emits a UserWarning for
        # any non-zero value, so pass 0.0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return one prediction per sequence in the batch.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``; the LSTM is
                built with ``batch_first=True`` and the output is indexed as
                a 3-D tensor.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        out, _ = self.lstm(x)
        # Keep only the top-layer output at the final time step.
        h_last = out[:, -1, :]
        h_drop = self.dropout(h_last)
        return self.fc(h_drop)