import graphviz
import json
from tempfile import NamedTemporaryFile
import os

def generate_process_flow_diagram(json_input: str, output_format: str) -> str:
    """
    Generates a Process Flow Diagram (Flowchart) from JSON input.

    Args:
        json_input (str): A JSON string describing the process flow structure.
                          It must follow the Expected JSON Format Example below.
                          Each node needs an "id"; its "label" defaults to the
                          id and its "type" defaults to "default". Each
                          connection needs "from" and "to"; "label" is optional.

        output_format (str): The output format for the generated diagram.
                            Supported formats: "png" or "svg"

    The "start_node" value is used both as the start node's id and as its
    label; it is drawn as a filled oval. Connections may refer to the start
    node either by that value or by the literal id "start_node" (as the
    example below does), unless a node with id "start_node" is explicitly
    defined in "nodes".

    Connections that reference an id not listed in "nodes" are still emitted;
    Graphviz then creates such nodes implicitly with default styling.

    Expected JSON Format Example:
    {
      "start_node": "Start Inference Request",
      "nodes": [
        {
          "id": "user_input",
          "label": "Receive User Input (Data)",
          "type": "io"
        },
        {
          "id": "preprocess_data",
          "label": "Preprocess Data",
          "type": "process"
        },
        {
          "id": "validate_data",
          "label": "Validate Data Format/Type",
          "type": "decision"
        },
        {
          "id": "data_valid_yes",
          "label": "Data Valid?",
          "type": "decision"
        },
        {
          "id": "load_model",
          "label": "Load AI Model (if not cached)",
          "type": "process"
        },
        {
          "id": "run_inference",
          "label": "Run AI Model Inference",
          "type": "process"
        },
        {
          "id": "postprocess_output",
          "label": "Postprocess Model Output",
          "type": "process"
        },
        {
          "id": "send_response",
          "label": "Send Response to User",
          "type": "io"
        },
        {
          "id": "log_error",
          "label": "Log Error & Notify User",
          "type": "process"
        },
        {
          "id": "end_inference_process",
          "label": "End Inference Process",
          "type": "end"
        }
      ],
      "connections": [
        {"from": "start_node", "to": "user_input", "label": "Request"},
        {"from": "user_input", "to": "preprocess_data", "label": "Data Received"},
        {"from": "preprocess_data", "to": "validate_data", "label": "Cleaned"},
        {"from": "validate_data", "to": "data_valid_yes", "label": "Check"},
        {"from": "data_valid_yes", "to": "load_model", "label": "Yes"},
        {"from": "data_valid_yes", "to": "log_error", "label": "No"},
        {"from": "load_model", "to": "run_inference", "label": "Model Ready"},
        {"from": "run_inference", "to": "postprocess_output", "label": "Output Generated"},
        {"from": "postprocess_output", "to": "send_response", "label": "Ready"},
        {"from": "send_response", "to": "end_inference_process", "label": "Response Sent"},
        {"from": "log_error", "to": "end_inference_process", "label": "Error Handled"}
      ]
    }

    Returns:
        str: The filepath to the generated image file. This function never
        raises: on any failure (empty input, invalid JSON, malformed
        structure, rendering error) it returns a message starting with
        "Error: " instead.
    """
    try:
        if not json_input.strip():
            return "Error: Empty input"

        data = json.loads(json_input)

        required_fields = ('start_node', 'nodes', 'connections')
        # A non-object top-level value (e.g. a JSON list) cannot carry the
        # required fields either, even if it happens to contain those strings.
        if not isinstance(data, dict) or any(key not in data for key in required_fields):
            raise ValueError("Missing required fields: 'start_node', 'nodes', or 'connections'")

        nodes = data['nodes']
        connections = data['connections']
        if not isinstance(nodes, list) or not isinstance(connections, list):
            raise ValueError("'nodes' and 'connections' must be lists")
        # Validate structure up front so malformed entries yield a clear
        # message rather than a bare KeyError text such as "Error: 'id'".
        if not all(isinstance(node, dict) and 'id' in node for node in nodes):
            raise ValueError("Each node must be an object with an 'id' field")
        if not all(
            isinstance(conn, dict) and 'from' in conn and 'to' in conn
            for conn in connections
        ):
            raise ValueError("Each connection must be an object with 'from' and 'to' fields")

        node_shapes = {
            "process": "box",
            "decision": "diamond",
            "start": "oval",
            "end": "oval",
            "io": "parallelogram",
            "document": "note",
            "default": "box"
        }

        node_colors = {
            "process": "#BEBEBE",
            "decision": "#FFF9C4",
            "start": "#A8E6CF",
            "end": "#FFB3BA",
            "io": "#B8D4F1",
            "document": "#F0F8FF",
            "default": "#BEBEBE"
        }

        # Styling shared by every node, including the start node.
        node_style = {
            'style': 'filled,rounded',
            'fontcolor': 'black',
            'fontsize': '14',
        }

        # NOTE(review): Graphviz documents limited edge-label support with
        # splines='ortho'; confirm labels are placed acceptably in the output.
        dot = graphviz.Digraph(
            name='ProcessFlowDiagram',
            format=output_format,
            graph_attr={
                'rankdir': 'TB',
                'splines': 'ortho',
                'bgcolor': 'white',
                'pad': '0.5',
                'nodesep': '0.6',
                'ranksep': '0.8'
            }
        )

        # If an id is listed more than once, the last definition wins.
        all_defined_nodes = {node['id']: node for node in nodes}

        start_node_id = data['start_node']
        dot.node(
            start_node_id,
            start_node_id,
            shape=node_shapes['start'],
            fillcolor=node_colors['start'],
            **node_style,
        )

        for node_id, node_info in all_defined_nodes.items():
            # A defined node sharing the start node's id was already drawn above.
            if node_id == start_node_id:
                continue

            node_type = node_info.get("type", "default")
            dot.node(
                node_id,
                node_info.get('label', node_id),
                shape=node_shapes.get(node_type, node_shapes["default"]),
                fillcolor=node_colors.get(node_type, node_colors["default"]),
                **node_style,
            )

        # The documented format wires edges from the literal id "start_node",
        # but the start node is created under the *value* of data['start_node'].
        # Map that alias so those edges attach to the drawn start node instead of
        # spawning a separate, unstyled "start_node" node.
        alias_start = 'start_node' not in all_defined_nodes

        def resolve(node_ref):
            """Map the "start_node" alias to the real start node id."""
            if alias_start and node_ref == 'start_node':
                return start_node_id
            return node_ref

        for connection in connections:
            dot.edge(
                resolve(connection['from']),
                resolve(connection['to']),
                label=connection.get('label', ''),
                color='#4a4a4a',
                fontcolor='#4a4a4a',
                fontsize='10'
            )

        # Reserve a unique path, then close the handle *before* Graphviz writes
        # to it: an open NamedTemporaryFile cannot be reopened or removed by
        # name on Windows.
        with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
            base_path = tmp.name

        try:
            # Writes the DOT source to base_path, renders the image next to it,
            # deletes the DOT source (cleanup=True), and returns the image path.
            return dot.render(base_path, format=output_format, cleanup=True)
        except Exception:
            # Don't leave the reserved placeholder / DOT source behind.
            try:
                if os.path.exists(base_path):
                    os.remove(base_path)
            except OSError:
                pass
            raise

    except json.JSONDecodeError:
        return "Error: Invalid JSON format"
    except Exception as e:
        return f"Error: {e}"