import html
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone

import gradio as gr
import requests

# Hugging Face Hub endpoint returning the current Daily Papers entries.
API_URL = "https://huggingface.co/api/daily_papers"
# Endpoint listing the repos (models, datasets, Spaces) associated with an
# arXiv id; fill it in with ``.format(arxiv_id=...)``.
REPOS_API_URL_TEMPLATE = "https://huggingface.co/api/arxiv/{arxiv_id}/repos"

class PaperManager:
    """
    Fetch, rank, paginate and render Hugging Face Daily Papers as HTML.

    ``raw_papers`` holds the entries returned by the Daily Papers API, each
    enriched with ``models``/``datasets``/``spaces`` counts. ``papers`` holds
    the same entries in the current sort order. ``current_page`` and
    ``total_pages`` drive the Prev/Next pagination.
    """

    # Sort keys accepted by set_sort_method; anything else falls back to "hot".
    VALID_SORT_METHODS = ("hot", "new", "most_models", "most_datasets", "most_spaces")

    # Map each repo-type spelling to its counter. Singular forms are accepted
    # alongside the plural ones the original code matched.
    # NOTE(review): confirm which spelling the repos endpoint actually returns.
    _REPO_TYPE_BUCKETS = {
        'model': 'models', 'models': 'models',
        'dataset': 'datasets', 'datasets': 'datasets',
        'space': 'spaces', 'spaces': 'spaces',
    }

    def __init__(self, papers_per_page=30, request_timeout=10):
        """
        Args:
            papers_per_page: Number of papers rendered per page.
            request_timeout: Seconds each HTTP request may take before it is
                abandoned, so a stalled endpoint cannot block a callback forever.
        """
        self.papers_per_page = papers_per_page
        self.request_timeout = request_timeout
        self.current_page = 1
        self.papers = []  # raw_papers in the current sort order
        self.total_pages = 1
        self.sort_method = "hot"  # Default sort method
        self.raw_papers = []  # Entries as fetched, plus repo counts

    @staticmethod
    def _paper_info(paper):
        """Return the nested ``paper`` dict of an entry, or ``{}`` if it is missing or null."""
        return paper.get('paper') or {}

    @staticmethod
    def _parse_published_at(value):
        """
        Parse an ISO-8601 ``publishedAt`` value into a timezone-aware datetime.

        A trailing ``Z`` is accepted and naive timestamps are treated as UTC.
        Missing, non-string or unparseable values fall back to "now", which
        keeps their effect on age-based sorting small.
        """
        if isinstance(value, str):
            try:
                parsed = datetime.fromisoformat(value.replace('Z', '+00:00'))
            except ValueError:
                parsed = None
            if parsed is not None:
                # A naive datetime cannot be subtracted from the aware "now".
                if parsed.tzinfo is None:
                    parsed = parsed.replace(tzinfo=timezone.utc)
                return parsed
        return datetime.now(timezone.utc)

    def calculate_score(self, paper):
        """
        Return the "hotness" score of a paper from its upvotes and age.

        Mimics the Hacker News ranking: ``upvotes / (age_hours + 2) ** 1.5``.
        The age is clamped at zero. A ``publishedAt`` in the future would
        otherwise give a negative base, which ``**`` turns into a complex
        number that cannot be sorted.
        """
        upvotes = self._paper_info(paper).get('upvotes') or 0
        published_time = self._parse_published_at(paper.get('publishedAt'))
        age_hours = max((datetime.now(timezone.utc) - published_time).total_seconds() / 3600, 0.0)
        # The "+ 2" keeps the denominator away from zero for brand-new papers.
        return upvotes / ((age_hours + 2) ** 1.5)

    @classmethod
    def _count_repos(cls, data):
        """
        Count models, datasets and Spaces in a repos-endpoint payload.

        Two payload shapes are accepted:
        - a list of repo objects carrying a ``type`` field (the shape the
          original code expected);
        - a mapping from repo type to a list of repos.

        Unknown types are printed and ignored.
        NOTE(review): the real payload shape is not visible here — confirm.
        """
        counts = {'models': 0, 'datasets': 0, 'spaces': 0}
        if isinstance(data, dict):
            for repo_type, repos in data.items():
                bucket = cls._REPO_TYPE_BUCKETS.get(str(repo_type).strip().lower())
                if bucket and isinstance(repos, list):
                    counts[bucket] += len(repos)
            return counts

        for repo in data or []:
            raw_type = repo.get('type') if isinstance(repo, dict) else None
            bucket = cls._REPO_TYPE_BUCKETS.get(str(raw_type or '').strip().lower())
            if bucket:
                counts[bucket] += 1
            else:
                print(f"Unknown repo type: {raw_type}")  # Debugging unknown types
        return counts

    def fetch_repos_counts(self, arxiv_id):
        """
        Fetch how many models, datasets and Spaces reference ``arxiv_id``.

        Returns:
            ``{'models': int, 'datasets': int, 'spaces': int}``. All counts are
            zero when ``arxiv_id`` is empty or on any request, decoding or
            parsing error. Errors are printed, never raised, so one bad paper
            cannot abort the concurrent fetch in fetch_papers.
        """
        zero_counts = {'models': 0, 'datasets': 0, 'spaces': 0}
        if not arxiv_id:
            print("Empty arxiv_id provided.")
            return zero_counts

        try:
            print(f"Fetching repositories for arxiv_id: {arxiv_id}")
            response = requests.get(
                REPOS_API_URL_TEMPLATE.format(arxiv_id=arxiv_id),
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            counts = self._count_repos(response.json())
        except requests.RequestException as e:
            print(f"HTTP error fetching repos for arxiv_id {arxiv_id}: {e}")
            return zero_counts
        except ValueError as e:
            print(f"JSON decoding error for arxiv_id {arxiv_id}: {e}")
            return zero_counts
        except Exception as e:
            print(f"Unexpected error fetching repos for arxiv_id {arxiv_id}: {e}")
            return zero_counts

        print(f"Counts for {arxiv_id}: {counts}")  # Debugging
        return counts

    def fetch_papers(self):
        """
        Fetch Daily Papers (``limit=100``), attach repo counts, sort, and reset to page 1.

        Returns:
            True on success. False if the request fails, the payload is empty
            or not a list, or any other error occurs; the error is printed.
        """
        try:
            response = requests.get(API_URL, params={'limit': 100}, timeout=self.request_timeout)
            response.raise_for_status()
            data = response.json()

            if not isinstance(data, list) or not data:
                print("No data received from API.")
                return False

            # The repos endpoint is keyed by arXiv id. Prefer an explicit
            # ``arxiv_id`` field; otherwise fall back to ``paper.id``, which
            # format_paper already uses to build the paper's URL.
            # NOTE(review): confirm ``paper.id`` is always the arXiv id.
            arxiv_ids = []
            for paper in data:
                info = self._paper_info(paper)
                arxiv_ids.append(info.get('arxiv_id') or info.get('id', ''))
            for arxiv_id in arxiv_ids[:5]:
                print(f"Sample arxiv_id: {arxiv_id}")

            # One HTTP call per paper, so run them concurrently.
            with ThreadPoolExecutor(max_workers=20) as executor:
                future_to_paper = {
                    executor.submit(self.fetch_repos_counts, arxiv_id): paper
                    for paper, arxiv_id in zip(data, arxiv_ids)
                }
                for future in as_completed(future_to_paper):
                    future_to_paper[future].update(future.result())

            # Publish the new data only once it is fully enriched.
            self.raw_papers = data
            self.sort_papers()
            self.total_pages = max((len(self.papers) + self.papers_per_page - 1) // self.papers_per_page, 1)
            self.current_page = 1
            return True
        except requests.RequestException as e:
            print(f"Error fetching papers: {e}")
            return False
        except Exception as e:
            print(f"Unexpected error: {e}")
            return False

    def sort_papers(self):
        """
        Rebuild ``papers`` from ``raw_papers`` according to ``sort_method``.

        Every order is descending, and an unknown method falls back to "hot".
        Missing or null fields sort as ``''`` (dates) or ``0`` (counts), so a
        single incomplete entry cannot make the comparison raise.
        """
        sort_keys = {
            "hot": self.calculate_score,
            # Uniformly formatted ISO-8601 strings sort chronologically as text.
            "new": lambda paper: paper.get('publishedAt') or '',
            "most_models": lambda paper: paper.get('models') or 0,
            "most_datasets": lambda paper: paper.get('datasets') or 0,
            "most_spaces": lambda paper: paper.get('spaces') or 0,
        }
        key = sort_keys.get(self.sort_method, self.calculate_score)
        self.papers = sorted(self.raw_papers, key=key, reverse=True)
        print(f"Papers sorted by {self.sort_method}")  # Debug

    def set_sort_method(self, method):
        """
        Switch the sort order, re-sort, and jump back to page 1.

        Unknown methods are replaced by "hot". Always returns True; the return
        value is kept for callers that check it.
        """
        if method not in self.VALID_SORT_METHODS:
            method = "hot"
        print(f"Setting sort method to: {method}")  # Debug
        self.sort_method = method
        self.sort_papers()
        self.current_page = 1
        return True

    def format_paper(self, paper, rank):
        """
        Render one paper as the HN-style table rows used by render_papers.

        Free-text API values (title, author names, URL) are HTML-escaped
        before interpolation. Unescaped ``<``/``&`` would otherwise break the
        markup or inject script into the page.
        """
        info = self._paper_info(paper)
        title = html.escape(str(paper.get('title') or info.get('title') or 'No title'))
        url = html.escape(f"https://huggingface.co/papers/{info.get('id', '')}")
        author_names = ', '.join(str(author.get('name', '')) for author in (info.get('authors') or []))
        authors = html.escape(author_names) or 'Unknown'
        upvotes = info.get('upvotes') or 0
        comments = paper.get('numComments', 0)

        age_days = (datetime.now(timezone.utc) - self._parse_published_at(paper.get('publishedAt'))).days
        if age_days <= 0:
            time_ago = "today"
        elif age_days == 1:
            time_ago = "1 day ago"
        else:
            time_ago = f"{age_days} days ago"

        models = paper.get('models', 0)
        datasets = paper.get('datasets', 0)
        spaces = paper.get('spaces', 0)

        return f"""
        <tr class="athing">
            <td align="right" valign="top" class="title"><span class="rank">{rank}.</span></td>
            <td valign="top" class="title">
                <a href="{url}" class="storylink" target="_blank">{title}</a>
            </td>
        </tr>
        <tr>
            <td colspan="1"></td>
            <td class="subtext">
                <span class="score">{upvotes} upvotes</span><br>
                authors: {authors} | {time_ago} | <a href="#">{comments} comments</a><br>
                Models: {models} | Datasets: {datasets} | Spaces: {spaces}
            </td>
        </tr>
        <tr style="height:5px"></tr>
        """

    def render_papers(self):
        """Return the HTML table for ``current_page``, or a notice if that page is empty."""
        start = (self.current_page - 1) * self.papers_per_page
        current_papers = self.papers[start:start + self.papers_per_page]

        if not current_papers:
            return "<div class='no-papers'>No papers available for this page.</div>"

        # Ranks continue across pages (page 2 starts at papers_per_page + 1).
        papers_html = "".join(
            self.format_paper(paper, rank)
            for rank, paper in enumerate(current_papers, start=start + 1)
        )
        return f"""
        <table border="0" cellpadding="0" cellspacing="0" class="itemlist">
            {papers_html}
        </table>
        """

    def next_page(self):
        """Advance one page (if not already on the last) and return its HTML."""
        if self.current_page < self.total_pages:
            self.current_page += 1
        print(f"Navigated to page {self.current_page}")  # Debug
        return self.render_papers()

    def prev_page(self):
        """Go back one page (if not already on the first) and return its HTML."""
        if self.current_page > 1:
            self.current_page -= 1
        print(f"Navigated to page {self.current_page}")  # Debug
        return self.render_papers()

# Single module-level manager used by every Gradio callback below.
# NOTE(review): because it is module-level, its page/sort state is presumably
# shared by all concurrent visitors — consider per-session gr.State if that matters.
paper_manager = PaperManager()

def _fetch_and_render(failure_message):
    """Re-fetch papers into the shared manager and render page 1, or return an error notice."""
    if paper_manager.fetch_papers():
        return paper_manager.render_papers()
    return f"<div class='no-papers'>{failure_message}</div>"


def initialize_app():
    """
    Populate the paper list when the page loads.

    The "Sort By" radio is created with value "Hot". The manager is a single
    module-level instance, so a sort chosen earlier would otherwise still be
    applied while the freshly loaded radio shows "Hot". Reset it first.
    """
    paper_manager.sort_method = "hot"
    return _fetch_and_render("Failed to fetch papers. Please try again later.")


def refresh_papers():
    """Re-fetch papers, keeping the current sort method, and render page 1."""
    return _fetch_and_render("Failed to refresh papers. Please try again later.")

def change_sort_method(method):
    """
    Gradio callback for the "Sort By" radio.

    Converts the radio label (e.g. "Most Models") into the manager's sort key
    (e.g. "most_models"), applies it, and returns the re-rendered list HTML.
    """
    sort_key = method.lower().replace(" ", "_")
    print(f"Changing sort method to: {sort_key}")  # Debug
    applied = paper_manager.set_sort_method(sort_key)
    if not applied:
        print("Failed to set sort method.")
        return "<div class='no-papers'>Failed to sort papers. Please try again later.</div>"
    print("Sort method set successfully.")
    return paper_manager.render_papers()

# Hacker-News-inspired stylesheet passed to gr.Blocks(css=...). It covers:
# - the orange header bar;
# - the ranked item list and grey subtext;
# - smaller fonts under 640px wide;
# - a dark palette when the browser prefers a dark color scheme.
css = """
body {
    background-color: white;
    font-family: Verdana, Geneva, sans-serif;
    margin: 0;
    padding: 0;
}

a {
    color: #0000ff;
    text-decoration: none;
}

a:visited {
    color: #551A8B;
}

.container {
    width: 85%;
    margin: auto;
}

table {
    width: 100%;
}

.header-table {
    width: 100%;
    background-color: #ff6600;
    padding: 2px 10px;
}

.header-table a {
    color: black;
    font-weight: bold;
    font-size: 14pt;
    text-decoration: none;
}

.itemlist .athing {
    background-color: #f6f6ef;
}

.rank {
    font-size: 14pt;
    color: #828282;
    padding-right: 5px;
}

.storylink {
    font-size: 10pt;
}

.subtext {
    font-size: 8pt;
    color: #828282;
    padding-left: 40px;
}

.subtext a {
    color: #828282;
    text-decoration: none;
}

.no-papers {
    text-align: center;
    color: #828282;
    padding: 1rem;
    font-size: 14pt;
}

@media (max-width: 640px) {
    .header-table a {
        font-size: 12pt;
    }

    .storylink {
        font-size: 9pt;
    }

    .subtext {
        font-size: 7pt;
    }
}

/* Dark mode */
@media (prefers-color-scheme: dark) {
    body {
        background-color: #121212;
        color: #e0e0e0;
    }

    a {
        color: #add8e6;
    }

    a:visited {
        color: #9370db;
    }

    .header-table {
        background-color: #ff6600;
    }

    .header-table a {
        color: black;
    }

    .itemlist .athing {
        background-color: #1e1e1e;
    }

    .rank {
        color: #b0b0b0;
    }

    .subtext {
        color: #b0b0b0;
    }

    .subtext a {
        color: #b0b0b0;
    }

    .no-papers {
        color: #b0b0b0;
    }
}
"""

# Build the UI. Event handlers are wired up inside the Blocks context.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes=["container"]):
        # Accordion for Submission Instructions
        with gr.Accordion("How to Submit a Paper", open=False):
            gr.Markdown("""
            **Submit the paper to Daily Papers:**
            [https://huggingface.co/papers/submit](https://huggingface.co/papers/submit)

            Once your paper is submitted, it will automatically appear in this demo.
            """)
        # Header without Refresh Button
        with gr.Row():
            gr.HTML("""
            <table border="0" cellpadding="0" cellspacing="0" class="header-table">
                <tr>
                    <td>
                        <span class="pagetop">
                            <b class="hnname"><a href="#">Daily Papers</a></b>
                        </span>
                    </td>
                </tr>
            </table>
            """)
        # Sort Options (labels are mapped to sort keys by change_sort_method)
        with gr.Row():
            sort_radio = gr.Radio(
                choices=["Hot", "New", "Most Models", "Most Datasets", "Most Spaces"],
                value="Hot",
                label="Sort By",
                interactive=True
            )
        # Paper list
        paper_list = gr.HTML()
        # Navigation Buttons
        with gr.Row():
            prev_button = gr.Button("Prev")
            next_button = gr.Button("Next")

    # Load papers on app start
    demo.load(initialize_app, outputs=[paper_list])

    # Button clicks for pagination
    prev_button.click(paper_manager.prev_page, outputs=[paper_list])
    next_button.click(paper_manager.next_page, outputs=[paper_list])

    # Sort option change
    sort_radio.change(
        fn=change_sort_method,
        inputs=[sort_radio],
        outputs=[paper_list]
    )

# Launch only when run as a script, so importing this module (e.g. from tests)
# does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()