from haystack.document_stores import InMemoryDocumentStore

from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.nodes.retriever import EmbeddingRetriever
import pickle
from pprint import pprint
# Display names of the two selectable datasets. The module-level do_search()
# compares its `dataset` argument against german_datset_name and falls back to
# the Dutch engine for any other value. dutch_datset_name is not referenced in
# the code visible here; presumably it is used by the UI layer -- confirm.
# NOTE(review): "datset" is a typo, kept because other modules may import these
# names.
dutch_datset_name = 'Partisan news 2019 (dutch)'
german_datset_name = 'CDU election program 2021'

class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    InMemoryDocumentStore whose ``indexes`` can be saved to and restored from disk.

    A deployment on Huggingface Spaces has no GPU, so the store is filled with
    pre-calculated data: write it out once with ``export`` and read it back at
    startup with ``load_data``.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle this store's ``indexes`` into *file_name*, overwriting the file."""
        with open(file_name, mode='wb') as sink:
            pickle.dump(self.indexes, sink)

    def load_data(self, file_name='in_memory_store.pkl'):
        """
        Replace this store's ``indexes`` with the object unpickled from *file_name*.

        Security: unpickling can execute arbitrary code, so only load files
        produced by ``export`` from a trusted source.
        """
        with open(file_name, mode='rb') as source:
            restored = pickle.load(source)
        self.indexes = restored


class SearchEngine:
    """
    One sparse (TF-IDF) and two dense (embedding) retrievers over a single dataset.

    Args:
        document_store_name_base: Pickle file written by
            ``ExportableInMemoryDocumentStore.export``. It backs the TF-IDF
            retriever and the base dense retriever.
        document_store_name_adpated: Pickle file backing the fine-tuned dense
            retriever. The misspelled name is kept for keyword-argument
            compatibility.
        adapted_retriever_path: Name or path of the fine-tuned
            sentence-transformers model.
    """

    def __init__(self, document_store_name_base, document_store_name_adpated,
                 adapted_retriever_path):
        self.document_store = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store.load_data(document_store_name_base)

        self.document_store_adapted = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store_adapted.load_data(document_store_name_adpated)

        # Created after load_data so the retriever is built over the loaded
        # documents (assumes TfidfRetriever fits on construction -- verify
        # against the installed haystack version).
        self.retriever = TfidfRetriever(document_store=self.document_store)

        self.base_dense_retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            model_format='sentence_transformers'
        )

        self.fine_tuned_retriever = EmbeddingRetriever(
            document_store=self.document_store_adapted,
            embedding_model=adapted_retriever_path,
            model_format='sentence_transformers'
        )

        # The pipelines only wrap the retrievers above, so build them once
        # here instead of on every query.
        self._sparse_pipeline = DocumentSearchPipeline(self.retriever)
        self._dense_pipelines = {
            'base': DocumentSearchPipeline(self.base_dense_retriever),
            'adapted': DocumentSearchPipeline(self.fine_tuned_retriever),
        }

    def sparse_retrieval(self, query):
        """
        Run the TF-IDF pipeline for *query* and return its output dict.

        The ``score`` of the first returned document is overwritten with the
        first value of ``TfidfRetriever._calc_scores(query)[0]``. If the
        pipeline returns no documents, the result is returned unchanged.
        """
        result = self._sparse_pipeline.run(query=query)
        documents = result['documents']
        if documents:
            # NOTE(review): _calc_scores is private haystack API. Presumably it
            # returns one mapping per query, ordered best-first -- verify.
            scores = self.retriever._calc_scores(query)
            if scores and scores[0]:
                documents[0].score = next(iter(scores[0].values()))
        return result

    def dense_retrieval(self, query, retriever='base'):
        """
        Run embedding retrieval for *query* and return the pipeline output dict.

        Args:
            query: The search string.
            retriever: ``'base'`` for the stock multilingual model, or
                ``'adapted'`` for the fine-tuned model.

        Raises:
            ValueError: If *retriever* is neither ``'base'`` nor ``'adapted'``.
        """
        try:
            pipeline = self._dense_pipelines[retriever]
        except KeyError:
            raise ValueError(
                f"Unknown retriever {retriever!r}; expected 'base' or 'adapted'"
            ) from None
        return pipeline.run(query=query)

    def do_search(self, query):
        """
        Return the top document from each retriever.

        The result is the tuple ``(sparse, dense_base, dense_adapted)``.
        Raises IndexError if any retriever returns no documents.
        """
        sparse_result = self.sparse_retrieval(query)['documents'][0]
        dense_base_result = self.dense_retrieval(query, 'base')['documents'][0]
        dense_adapted_result = self.dense_retrieval(query, 'adapted')['documents'][0]
        return sparse_result, dense_base_result, dense_adapted_result


# Both engines are built at import time. Each one unpickles two document stores
# and loads two sentence-transformers models, which can make importing this
# module slow.
dutch_search_engine = SearchEngine(
    document_store_name_base='dutch-article-idx.pkl',
    document_store_name_adpated='dutch-article-idx_adapted.pkl',
    adapted_retriever_path='dutch-article-retriever',
)
german_search_engine = SearchEngine(
    document_store_name_base='documentstore_german-election-idx.pkl',
    document_store_name_adpated='documentstore_german-election-idx_adapted.pkl',
    adapted_retriever_path='adapted-retriever',
)

def do_search(query, dataset):
    """
    Search *query* in the engine selected by *dataset*.

    The German engine is used only when *dataset* equals ``german_datset_name``.
    Every other value falls back to the Dutch engine. Returns that engine's
    ``(sparse, dense_base, dense_adapted)`` top documents.
    """
    engine = german_search_engine if dataset == german_datset_name else dutch_search_engine
    return engine.do_search(query)

if __name__ == '__main__':
    # Reuse the Dutch engine already built at import time. Constructing a second
    # SearchEngine from the same files would unpickle both stores and load both
    # embedding models again, only to produce an identical engine.
    search_engine = dutch_search_engine
    query = 'Kindergarten'

    result = search_engine.do_search(query)
    pprint(result)