# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 ISeeTheFuture
"""Rolling next-step GNSS residual inference with an LSTM model.

Example::

    python inference.py --csv samples/sample.csv
"""
import argparse
import json
import os
import warnings
from typing import List

import joblib
import numpy as np
import pandas as pd
import torch

# === Features used at training time ===
FEATURE_COLS: List[str] = [
    "latitude", "longitude", "altitude",
    "accelerometer_x", "accelerometer_y", "accelerometer_z",
    "gyroscope_x", "gyroscope_y", "gyroscope_z",
    "compass",
]

# History window length for the model (uses the last 50 rows to predict the next row)
HIST_LEN_DEFAULT = 50  # requires at least HIST_LEN+1 (=51) rows to produce one output

# === Default file locations ===
# Relative paths are resolved against the current working directory, so run the
# script from the model repo root (or pass explicit paths on the command line).
DEFAULT_WEIGHTS = "1753670088.7075965_lstm_corr.pth"
DEFAULT_SCALER_X = "scalers/1753670088.7075965_scaler_X.pkl"
DEFAULT_SCALER_Y = "scalers/1753670088.7075965_scaler_y.pkl"
DEFAULT_CONFIG = "config.json"

# === Model class ===
from model import GPSCorrectionLSTM  # noqa: E402  # __init__(input_size, hidden_size=128, num_layers=2, dropout=0.3)


def load_config(cfg_path: str) -> dict:
    """Return the parsed JSON object stored at ``cfg_path``, or ``{}`` if the file is absent."""
    if os.path.exists(cfg_path):
        with open(cfg_path, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}


def build_model(input_size: int, cfg: dict) -> torch.nn.Module:
    """Instantiate the model with hyperparameters from config.json if available.

    Args:
        input_size: Number of input features per time step.
        cfg: Parsed config; ``hidden_size``, ``num_layers`` and ``dropout`` are read
            from it, falling back to 128 / 2 / 0.3.

    Returns:
        A freshly constructed ``GPSCorrectionLSTM`` in eval mode (weights not loaded).
    """
    hidden_size = int(cfg.get("hidden_size", 128))
    num_layers = int(cfg.get("num_layers", 2))
    dropout = float(cfg.get("dropout", 0.3))
    # output_size was 2 in training (res_lat, res_lon); keep flexible if your class needs it.
    try:
        model = GPSCorrectionLSTM(input_size, hidden_size=hidden_size,
                                  num_layers=num_layers, dropout=dropout)
    except TypeError:
        # Older model class versions may not accept a ``dropout`` keyword.
        model = GPSCorrectionLSTM(input_size, hidden_size=hidden_size, num_layers=num_layers)
    model.eval()
    return model


def load_scaler(path: str):
    """Load a Joblib scaler if present; otherwise warn and return None (no scaling)."""
    if os.path.exists(path):
        # SECURITY: joblib.load unpickles the file; only load scalers you trust.
        return joblib.load(path)
    warnings.warn(f"[WARN] scaler not found: {path}. Proceeding without scaling.")
    return None


def load_df_from_csv(path: str) -> pd.DataFrame:
    """Load a CSV, sort it by ``timestamp`` if present, and validate the feature columns.

    Raises:
        ValueError: if any of ``FEATURE_COLS`` is missing.
    """
    df = pd.read_csv(path)
    if "timestamp" in df.columns:
        # Stable sort keeps the file order of rows that share a timestamp.
        df = df.sort_values("timestamp", kind="stable")
    missing = [c for c in FEATURE_COLS if c not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing columns: {missing}")
    return df.reset_index(drop=True)


def scale_window(X_win: np.ndarray, scaler_X):
    """Apply the feature scaler row-wise to a single (T, F) window, if a scaler is given."""
    if scaler_X is None:
        return X_win
    T, F = X_win.shape
    return scaler_X.transform(X_win.reshape(-1, F)).reshape(T, F)


def inverse_y(y: np.ndarray, scaler_y):
    """Inverse-transform a single (2,) or (3,) prediction if a target scaler is provided."""
    if scaler_y is None:
        return y
    return scaler_y.inverse_transform(y.reshape(1, -1)).reshape(-1)


def predict_next_residual(model: torch.nn.Module, X_win_tf: np.ndarray, device: str = "cpu") -> np.ndarray:
    """Run the model on one (HIST_LEN, F) window and return its output for that window.

    The window is cast to float32 and given a leading batch dimension of 1; that
    batch dimension is squeezed off the result. The output is presumably
    [res_lat, res_lon(, res_alt)], i.e. shape (2,) or (3,) — depends on the model.
    """
    x = torch.from_numpy(X_win_tf.astype(np.float32)).unsqueeze(0).to(device)  # (1, T, F)
    with torch.no_grad():
        y = model(x).squeeze(0).detach().cpu().numpy()
    return y


def _timestamp_to_json(value):
    """Return a JSON-friendly timestamp: ``float(value)`` when possible, else ``str(value)``.

    Numeric timestamps are emitted as floats. Values ``float()`` rejects (e.g. ISO-8601
    strings or pandas Timestamps) are emitted as strings instead of crashing.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return str(value)


def _load_input(args):
    """Return ``(features, timestamps)`` from the selected input source.

    ``features`` is a float64 array (expected shape (T, F)). ``timestamps`` is the
    per-row ``timestamp`` column for CSV input that has one, otherwise None.
    """
    if args.json:
        return np.asarray(json.loads(args.json), dtype=np.float64), None
    if args.json_file:
        with open(args.json_file, "r", encoding="utf-8") as f:
            return np.asarray(json.load(f), dtype=np.float64), None
    df = load_df_from_csv(args.csv)
    timestamps = df["timestamp"].to_numpy() if "timestamp" in df.columns else None
    return df[FEATURE_COLS].to_numpy(dtype=np.float64), timestamps


def _load_weights(model: torch.nn.Module, weights_path: str, device: str) -> torch.nn.Module:
    """Load ``weights_path`` and return the model to run, in eval mode.

    The checkpoint may be a state_dict (loaded into ``model``) or a pickled full
    ``torch.nn.Module`` (moved to ``device`` and used instead of ``model``).
    """
    # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects pickled
    # full-module checkpoints — confirm the targeted torch version if that path matters.
    state = torch.load(weights_path, map_location=device)
    if isinstance(state, torch.nn.Module):
        model = state.to(device)
    else:
        # Let key/shape mismatches raise load_state_dict's own descriptive error
        # rather than masking it behind an AttributeError from ``dict.to``.
        model.load_state_dict(state)
    model.eval()
    return model


# python inference.py --csv samples/sample.csv
def main():
    """Parse CLI args, run rolling next-step inference, and print the results as JSON."""
    ap = argparse.ArgumentParser(
        description="Rolling inference for next-step GNSS residuals using an LSTM model. "
                    "Uses the last HIST_LEN rows to predict the next row. "
                    "If the CSV has N rows and N >= HIST_LEN+1, this script outputs corrected coordinates "
                    "for rows [HIST_LEN ... N-1] (i.e., 51st to last)."
    )
    src = ap.add_mutually_exclusive_group(required=True)
    src.add_argument("--json", type=str, help="JSON string of shape [T, F]")
    src.add_argument("--json-file", type=str, help="Path to a JSON file (shape [T, F])")
    src.add_argument("--csv", type=str, help="Path to a CSV with columns: " + ",".join(FEATURE_COLS))
    ap.add_argument("--weights", default=DEFAULT_WEIGHTS, help="Model weights (state_dict or full model object).")
    ap.add_argument("--scaler-x", default=DEFAULT_SCALER_X, help="Feature scaler (Joblib).")
    ap.add_argument("--scaler-y", default=DEFAULT_SCALER_Y, help="Target scaler (Joblib).")
    ap.add_argument("--config", default=DEFAULT_CONFIG, help="Model hyperparameters (config.json).")
    ap.add_argument("--hist-len", type=int, default=HIST_LEN_DEFAULT,
                    help="History window length used by the model (default: 50).")
    args = ap.parse_args()

    H = int(args.hist_len)
    if H < 1:
        ap.error(f"--hist-len must be >= 1, got {H}")

    # 1) Load input. ``raw`` keeps full float64 precision for the coordinates echoed
    #    in the output; the scalers/model still receive float32 (``arr``) as before.
    raw, timestamps = _load_input(args)
    if raw.ndim != 2:
        raise ValueError(
            f"Input must be a 2-D array of shape [T, {len(FEATURE_COLS)}], got shape {raw.shape}."
        )
    T, F = raw.shape
    if F != len(FEATURE_COLS):
        raise ValueError(f"Input feature dimension must be {len(FEATURE_COLS)}, got {F}.")
    arr = raw.astype(np.float32)
    if T <= H:
        warnings.warn(
            f"[WARN] need at least {H + 1} rows to produce a prediction, got {T}; "
            "no outputs will be produced."
        )

    # 2) Build & load model and scalers
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = load_config(args.config)
    model = build_model(input_size=F, cfg=cfg).to(device)
    model = _load_weights(model, args.weights, device)

    scaler_X = load_scaler(args.scaler_x)
    scaler_y = load_scaler(args.scaler_y)

    results = []
    # Rolling inference for indices i = H .. T-1: the window arr[i-H : i] predicts the
    # residual for row i, which is added to the noisy GNSS reading at row i.
    for i in range(H, T):
        X_win_tf = scale_window(arr[i - H : i, :], scaler_X)  # (H, F)
        y_pred = predict_next_residual(model, X_win_tf, device=device)  # (2,) or (3,)
        y_pred_deg = inverse_y(y_pred, scaler_y)

        res_lat = float(y_pred_deg[0])
        res_lon = float(y_pred_deg[1])
        # Noisy GNSS at step i (the "next" row after the window), at full input precision.
        noisy_lat = float(raw[i, 0])
        noisy_lon = float(raw[i, 1])

        out = {
            "index": int(i),  # 0-based row index (after timestamp sorting for CSV input)
            "noisy_next_lat_deg": noisy_lat,
            "noisy_next_lon_deg": noisy_lon,
            "pred_residual_lat_deg": res_lat,
            "pred_residual_lon_deg": res_lon,
            "corrected_next_lat_deg": noisy_lat + res_lat,
            "corrected_next_lon_deg": noisy_lon + res_lon,
        }

        # If the model also outputs an altitude residual
        if y_pred_deg.shape[0] >= 3:
            res_alt = float(y_pred_deg[2])
            noisy_alt = float(raw[i, 2])
            out.update({
                "noisy_next_alt_m": noisy_alt,
                "pred_residual_alt": res_alt,
                "corrected_next_alt_m": noisy_alt + res_alt,
            })

        if timestamps is not None:
            out["timestamp"] = _timestamp_to_json(timestamps[i])

        results.append(out)

    print(json.dumps({
        "history_len": H,
        "total_rows": T,
        "outputs": results,
    }, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()