File size: 7,416 Bytes
63ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import argparse
import json
import sys
from pathlib import Path
from typing import Any

from fastapi.openapi.utils import get_openapi

try:
    import yaml

    YAML_AVAILABLE = True
except ImportError:
    YAML_AVAILABLE = False


def get_app_instance():
    """
    Import and return the FastAPI app instance.

    The import of ``inference_server.main`` happens inside this function
    rather than at module level, so it only runs when the app is needed.

    Returns:
        FastAPI: The main FastAPI application instance

    Raises:
        SystemExit: With status 1 if ``inference_server.main`` cannot be
            imported; a diagnostic is written to stderr before exiting.

    """
    try:
        from inference_server.main import app
    except ImportError as e:
        # Diagnostics go to stderr so they never mix with a schema that a
        # caller may be streaming to stdout (e.g. the CLI's --print mode).
        print(f"❌ Error importing FastAPI app: {e}", file=sys.stderr)
        print(
            "Make sure you're running this from the correct directory and all dependencies are installed",
            file=sys.stderr,
        )
        sys.exit(1)
    return app


def create_custom_openapi_schema(app) -> dict[str, Any]:
    """
    Create a customized OpenAPI schema with enhanced metadata.

    The generated schema is cached on ``app.openapi_schema``. If that
    attribute is already truthy, it is returned as-is and no customization
    (tags, security schemes) is applied.

    Args:
        app: FastAPI application instance

    Returns:
        Dict[str, Any]: Complete OpenAPI schema dictionary

    """
    if app.openapi_schema:
        return app.openapi_schema

    # Generate the base OpenAPI schema from the app's routes.
    openapi_schema = get_openapi(
        title="Inference Server",
        version="1.0.0",
        summary="ACT Model Inference Server for Real-time Robot Control",
        routes=app.routes,
        openapi_version="3.1.0",
    )

    # Top-level tag descriptions used to group operations in rendered docs.
    openapi_schema["tags"] = [
        {"name": "Health", "description": "Health check and server status endpoints"},
        {
            "name": "Sessions",
            "description": "Inference session management - create, control, and monitor AI sessions",
        },
        {
            "name": "Control",
            "description": "Session control operations - start, stop, restart inference",
        },
        {
            "name": "Debug",
            "description": "Debug and monitoring endpoints for system diagnostics",
        },
    ]

    # Advertise security schemes for future authentication. Merge into any
    # existing securitySchemes instead of replacing them: get_openapi emits
    # schemes for routes that use FastAPI security dependencies, and those
    # operations' `security` requirements refer to the schemes by name.
    components = openapi_schema.setdefault("components", {})
    security_schemes = components.setdefault("securitySchemes", {})
    security_schemes.update(
        {
            "BearerAuth": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"},
            "ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-API-Key"},
        }
    )

    # Cache the schema so repeated calls return the same dict.
    app.openapi_schema = openapi_schema
    return openapi_schema


def export_openapi_schema(
    output_file: str | None = None, format_type: str = "json"
) -> dict[str, Any]:
    """
    Export the OpenAPI schema to a file or return it as a dictionary.

    Args:
        output_file (str, optional): Path to save the schema. If None, returns the schema dict.
        format_type (str): Format to export - 'json' or 'yaml'

    Returns:
        Dict[str, Any]: The OpenAPI schema dictionary

    Raises:
        ValueError: If format_type is not supported
        ImportError: If YAML is requested but not available

    """
    if format_type not in {"json", "yaml"}:
        msg = f"Unsupported format: {format_type}. Use 'json' or 'yaml'"
        raise ValueError(msg)

    if format_type == "yaml" and not YAML_AVAILABLE:
        msg = "PyYAML is required for YAML export. Install with: pip install pyyaml"
        raise ImportError(msg)

    # Get the FastAPI app and generate schema
    app = get_app_instance()
    openapi_schema = create_custom_openapi_schema(app)

    # If no output file specified, return the schema
    if output_file is None:
        return openapi_schema

    # Ensure output directory exists
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Save to file
    try:
        with open(output_path, "w", encoding="utf-8") as f:
            if format_type == "json":
                json.dump(openapi_schema, f, indent=2, ensure_ascii=False)
            else:  # yaml
                yaml.dump(
                    openapi_schema,
                    f,
                    default_flow_style=False,
                    allow_unicode=True,
                    sort_keys=False,
                )

        print(f"✅ OpenAPI schema exported to {output_path}")
        print(f"📄 Format: {format_type.upper()}")
        print(
            f"📊 Endpoints: {len([route for route in app.routes if hasattr(route, 'methods')])}"
        )

        return openapi_schema

    except Exception as e:
        print(f"❌ Error saving OpenAPI schema: {e}")
        sys.exit(1)


def main():
    """CLI entry point for OpenAPI export."""
    parser = argparse.ArgumentParser(
        description="Export OpenAPI schema from Inference Server",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Export as JSON (default)
  python -m inference_server.export_openapi

  # Export as YAML
  python -m inference_server.export_openapi --format yaml

  # Custom output file
  python -m inference_server.export_openapi --output api_schema.json

  # Specify both format and output
  python -m inference_server.export_openapi --format yaml --output docs/openapi.yaml
        """,
    )

    parser.add_argument(
        "--output",
        "-o",
        default=None,
        help="Output file path (default: openapi.json or openapi.yaml based on format)",
    )

    parser.add_argument(
        "--format",
        "-f",
        choices=["json", "yaml"],
        default="json",
        help="Output format (default: json)",
    )

    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate the generated schema (requires openapi-spec-validator)",
    )

    parser.add_argument(
        "--print",
        action="store_true",
        help="Print the schema to stdout instead of saving to file",
    )

    args = parser.parse_args()

    # Determine output file if not specified
    if args.output is None and not args.print:
        args.output = f"openapi.{args.format}"

    try:
        # Export the schema
        if args.print:
            schema = export_openapi_schema(output_file=None, format_type=args.format)
            if args.format == "json":
                print(json.dumps(schema, indent=2, ensure_ascii=False))
            else:
                print(
                    yaml.dump(
                        schema,
                        default_flow_style=False,
                        allow_unicode=True,
                        sort_keys=False,
                    )
                )
        else:
            schema = export_openapi_schema(
                output_file=args.output, format_type=args.format
            )

        # Validate schema if requested
        if args.validate:
            try:
                from openapi_spec_validator import validate_spec

                validate_spec(schema)
                print("✅ Schema validation passed")
            except ImportError:
                print(
                    "⚠️ Validation skipped: install openapi-spec-validator for validation"
                )
            except Exception as e:
                print(f"❌ Schema validation failed: {e}")
                sys.exit(1)

    except KeyboardInterrupt:
        print("\n🛑 Export cancelled by user")
        sys.exit(1)
    except Exception as e:
        print(f"❌ Export failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits
    # with status 0 unless main() has already exited with an error status.
    raise SystemExit(main())