import unittest
from unittest.mock import patch

import pandas as pd

import src.backend.evaluate_model as evaluate_model
import src.envs as envs


class TestEvaluator(unittest.TestCase):
    """Unit tests for ``evaluate_model.Evaluator``.

    ``SummaryGenerator``, ``EvaluationModel``, ``pd.read_csv`` and
    ``format_results`` are replaced with ``unittest.mock`` patches in the
    tests below, so the tests only check how ``Evaluator`` wires and calls
    those collaborators.
    """

    def setUp(self):
        # Placeholder constructor arguments shared by every test.
        self.model_name = 'test_model'
        self.revision = 'test_revision'
        self.precision = 'test_precision'
        self.batch_size = 10
        self.device = 'test_device'
        self.no_cache = False
        self.limit = 10

    def _make_evaluator(self):
        """Build an ``Evaluator`` from the fixture values set in ``setUp``.

        Call this while the test's ``@patch`` decorators are active so the
        constructor sees the mocks instead of the real collaborators.
        """
        return evaluate_model.Evaluator(self.model_name, self.revision,
                                        self.precision, self.batch_size,
                                        self.device, self.no_cache, self.limit)

    # Stacked @patch decorators inject mocks bottom-up: the decorator closest
    # to the ``def`` supplies the first mock argument.

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    def test_evaluator_initialization(self, mock_eval_model, mock_summary_generator):
        """The constructor builds both collaborators and stores the model name."""
        evaluator = self._make_evaluator()

        mock_summary_generator.assert_called_once_with(self.model_name, self.revision)
        mock_eval_model.assert_called_once_with(envs.HEM_PATH)
        self.assertEqual(evaluator.model, self.model_name)

    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('src.backend.evaluate_model.SummaryGenerator')
    def test_evaluator_initialization_error(self, mock_summary_generator, mock_eval_model):
        """Constructing an Evaluator raises when EvaluationModel raises."""
        mock_eval_model.side_effect = Exception('test_exception')
        # NOTE(review): assertRaises(Exception) also accepts unrelated errors
        # (e.g. a TypeError from a constructor signature change); consider
        # checking the message if Evaluator re-raises the original exception.
        with self.assertRaises(Exception):
            self._make_evaluator()

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('src.backend.evaluate_model.pd.read_csv')
    # NOTE(review): this patches the attribute on src.backend.util itself; it
    # only takes effect if evaluate_model looks format_results up through the
    # util module at call time (not via `from ... import format_results`).
    @patch('src.backend.util.format_results')
    def test_evaluate_method(self, mock_format_results, mock_read_csv, mock_eval_model,
                             mock_summary_generator):
        """evaluate() reads SOURCE_PATH and forwards the collected metrics to format_results."""
        evaluator = self._make_evaluator()

        # Configuring return_value after construction still works: calling a
        # MagicMock always returns the same return_value object, so the
        # instances the constructor obtained pick up these settings.
        mock_format_results.return_value = {'test': 'result'}
        mock_read_csv.return_value = pd.DataFrame({'column1': ['data1', 'data2']})
        mock_summary_generator.return_value.generate_summaries.return_value = pd.DataFrame({'column1': ['summary1', 'summary2']})
        mock_summary_generator.return_value.avg_length = 100
        mock_summary_generator.return_value.answer_rate = 1.0
        mock_summary_generator.return_value.error_rate = 0.0
        mock_eval_model.return_value.compute_accuracy.return_value = 1.0
        mock_eval_model.return_value.hallucination_rate = 0.0
        mock_eval_model.return_value.evaluate_hallucination.return_value = [0.5]

        evaluator.evaluate()

        mock_format_results.assert_called_once_with(model_name=self.model_name,
                                                    revision=self.revision,
                                                    precision=self.precision,
                                                    accuracy=1.0, hallucination_rate=0.0,
                                                    answer_rate=1.0, avg_summary_len=100,
                                                    error_rate=0.0)
        mock_read_csv.assert_called_once_with(envs.SOURCE_PATH)

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('src.backend.evaluate_model.pd.read_csv')
    def test_evaluate_with_file_not_found(self, mock_read_csv, mock_eval_model,
                                          mock_summary_generator):
        """A FileNotFoundError from reading the source CSV escapes evaluate()."""
        evaluator = self._make_evaluator()
        # Set the failure only after construction, so the assertion below
        # tests evaluate() itself and cannot be tripped by the constructor.
        mock_read_csv.side_effect = FileNotFoundError('test_exception')

        with self.assertRaises(FileNotFoundError):
            evaluator.evaluate()


if __name__ == "__main__":
    # Lets the file be run directly as a script; discovers and runs the
    # TestCase classes defined in this module.
    unittest.main()