Infatoshi/smolvla
This repository contains a smolvla_base policy trained with the lerobot framework.
Model Description
This model is a Vision-Language-Action (VLA) policy that can take visual observations, proprioceptive states, and a language instruction to predict robot actions.
- Policy Type: smolvla
- Dataset: gribok201/smolvla_koch4
- VLM Backbone: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
- Trained Steps: 10000
I/O Schema
Input Features:
- observation.image: type VISUAL, shape [3, 256, 256]
- observation.image2: type VISUAL, shape [3, 256, 256]
- observation.image3: type VISUAL, shape [3, 256, 256]
- observation.state: type STATE, shape [6]
Output Features:
- action: type ACTION, shape [6]
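The schema above can be written down as a plain dictionary and used to sanity-check a sample before inference. The following is a minimal sketch: the names `INPUT_FEATURES`, `OUTPUT_FEATURES`, and `check_shapes` are hypothetical helpers introduced for illustration and are not part of lerobot or this repository.

```python
# Feature names and shapes copied from the I/O schema above.
# The dictionary layout and the helper below are illustrative assumptions.
INPUT_FEATURES = {
    "observation.image": ("VISUAL", (3, 256, 256)),
    "observation.image2": ("VISUAL", (3, 256, 256)),
    "observation.image3": ("VISUAL", (3, 256, 256)),
    "observation.state": ("STATE", (6,)),
}
OUTPUT_FEATURES = {"action": ("ACTION", (6,))}


def check_shapes(sample_shapes, schema=INPUT_FEATURES):
    """Return a list of problems found in `sample_shapes`.

    `sample_shapes` maps feature name -> per-sample shape (no batch dimension).
    An empty list means every schema entry is present with the expected shape.
    """
    problems = []
    for name, (_kind, expected) in schema.items():
        if name not in sample_shapes:
            problems.append(f"missing feature: {name}")
        elif tuple(sample_shapes[name]) != expected:
            problems.append(
                f"{name}: expected {expected}, got {tuple(sample_shapes[name])}"
            )
    return problems
```

Calling `check_shapes` with the per-sample shapes of your tensors (for example `{k: tuple(v.shape[1:]) for k, v in batch.items()}` for batched tensors) returns an empty list when everything matches the schema.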
Image Preprocessing:
Images are expected to be resized to [512, 512] before being passed to the model.
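The config attribute read in the usage example, `resize_imgs_with_padding`, suggests that the resize preserves aspect ratio and pads the remainder. That behaviour is an assumption here, not something this card states. Under that assumption, the target geometry for a camera frame can be computed as in the sketch below; `fit_with_padding` is a hypothetical helper, and where the padding is placed (top, left, or split) is left to the caller.

```python
def fit_with_padding(src_h, src_w, dst_h=512, dst_w=512):
    """Scale (src_h, src_w) to fit inside (dst_h, dst_w), keeping aspect ratio.

    Returns (new_h, new_w, pad_h, pad_w), where pad_h and pad_w are the total
    number of padding pixels needed along each axis to reach the target size.
    """
    if src_h <= 0 or src_w <= 0:
        raise ValueError("source dimensions must be positive")
    # Integer cross-multiplication avoids floating-point rounding:
    # src_w / dst_w >= src_h / dst_h means width is the limiting side.
    if src_w * dst_h >= src_h * dst_w:
        new_w = dst_w
        new_h = max(1, src_h * dst_w // src_w)
    else:
        new_h = dst_h
        new_w = max(1, src_w * dst_h // src_h)
    return new_h, new_w, dst_h - new_h, dst_w - new_w
```

For a 480x640 (height x width) frame, `fit_with_padding(480, 640)` returns `(384, 512, 128, 0)`: the image is scaled to 384x512 and needs 128 rows of padding. A square 256x256 input scales to 512x512 with no padding.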
How to Use
This model can be loaded using transformers.AutoModel with trust_remote_code=True.
You MUST have lerobot installed in your environment for this to work (pip install lerobot).
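Because loading fails without lerobot, a quick pre-flight check can save a confusing traceback. The sketch below uses only the standard library; `missing_packages` and `REQUIRED` are hypothetical names, and the package list is inferred from the imports used in this card's example.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed requirement list: lerobot (stated above) plus the libraries
# imported by the loading example.
REQUIRED = ["lerobot", "transformers", "torch"]

if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Install before loading the model:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```

`find_spec` only locates the package without importing it, so the check is cheap and has no side effects from heavy imports such as torch.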
from transformers import AutoModel
import torch
from PIL import Image
import torchvision.transforms as T
# Replace with your model's repo_id
repo_id = "Infatoshi/smolvla"
# Load the model - CRITICAL: trust_remote_code=True
# This executes the custom code in modeling_lerobot_policy.py
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()
print("Model loaded successfully!")
# Example Inference:
# Create dummy inputs matching the model's expected schema.
resize_shape = tuple(model.config.resize_imgs_with_padding)
state_shape = tuple(model.config.input_features["observation.state"]["shape"])
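# Dummy observations dictionary. The camera keys ("usb", "brio") come from the
# original example and must match the names the model's custom code expects.
dummy_observations = {
    "state": torch.randn(1, *state_shape),
    "images": {
        "usb": torch.randn(1, 3, *resize_shape),
        "brio": torch.randn(1, 3, *resize_shape),
    },
}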
dummy_language_instruction = "pick up the cube"
with torch.no_grad():
    output = model(
        observations=dummy_observations,
        language_instruction=dummy_language_instruction,
    )
print("Inference output (predicted actions):", output)
print("Output shape:", output.shape)