import abc
from typing import List, Union

from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from .type_aliases import ENCODER_DEVICE_TYPE


class Encoder(abc.ABC):
    """Interface for objects that turn sentences into embedding vectors.

    Subclasses must override :meth:`encode`.
    """

    @abc.abstractmethod
    def encode(self, prediction: List[str]) -> NDArray:
        """Encode sentences into an array of embeddings.

        Args:
            prediction (List[str]): Sentences to encode.

        Returns:
            NDArray: Sentence embeddings of shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: Always, when a subclass does not override this method.
        """
        message = "Method 'encode' must be implemented in subclass."
        raise NotImplementedError(message)


class SBertEncoder(Encoder):
    """Encoder backed by a ``SentenceTransformer`` model.

    When ``device`` is a list, encoding goes through a sentence-transformers
    multi-process pool that is started and stopped on every ``encode`` call;
    otherwise the model encodes in-process on the single given device.
    """

    def __init__(self, model: SentenceTransformer, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
        """
        Initialize SBertEncoder instance.

        Args:
            model (SentenceTransformer): The Sentence Transformer model instance to use for encoding.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding. A list is
                passed as ``target_devices`` to a multi-process pool; any other value is passed as
                ``device`` to ``model.encode``.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to show a progress bar during single-device encoding.
        """
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, prediction: List[str]) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """
        # SBert output is always Batch x Dim
        if isinstance(self.device, list):
            # Use multiprocess encoding for list of devices
            pool = self.model.start_multi_process_pool(target_devices=self.device)
            try:
                embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
            finally:
                # Always tear the pool down, even if encoding raises; otherwise the
                # worker processes started above are leaked.
                self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding; `verbose` controls the progress bar.
            embeddings = self.model.encode(
                prediction,
                device=self.device,
                batch_size=self.batch_size,
                show_progress_bar=self.verbose,
            )

        return embeddings


def get_encoder(
    sbert_model: SentenceTransformer,
    device: ENCODER_DEVICE_TYPE,
    batch_size: int,
    verbose: bool,
) -> Encoder:
    """
    Wrap a SentenceTransformer model in an SBertEncoder.

    Args:
        sbert_model (SentenceTransformer): Model instance used to compute embeddings.
        device (Union[str, int, List[Union[str, int]]]): Device specification for the encoder
            (e.g., "cuda", 0 for GPU, "cpu"); a list of devices selects multi-process encoding.
        batch_size (int): Batch size to use for encoding.
        verbose (bool): Whether to print verbose information during encoding.

    Returns:
        Encoder: An SBertEncoder configured with the given model, device, batch size and verbosity.

    Example:
        >>> sbert_model = get_sbert_encoder("paraphrase-distilroberta-base-v1")
        >>> encoder = get_encoder(sbert_model, device="cuda", batch_size=32, verbose=True)
    """
    return SBertEncoder(
        model=sbert_model,
        device=device,
        batch_size=batch_size,
        verbose=verbose,
    )


def get_sbert_encoder(model_name: str, *, trust_remote_code: bool = True) -> SentenceTransformer:
    """
    Get an instance of SentenceTransformer encoder based on the specified model name.

    Args:
        model_name (str): Name of the model to instantiate. You can use any model on Huggingface/SentenceTransformer
            that is supported by SentenceTransformer.
        trust_remote_code (bool): Forwarded to ``SentenceTransformer``. When True (the default, kept for
            backward compatibility), custom code shipped in the model repository may be executed
            locally. Only enable this for repositories you trust; pass False to refuse remote code.

    Returns:
        SentenceTransformer: Instance of the selected encoder based on the model_name.

    Raises:
        EnvironmentError: If loading the model raises an EnvironmentError (e.g. presumably when the
            model_name cannot be resolved); only the original message is kept.
        RuntimeError: If any other exception occurs during instantiation of the encoder.
    """
    try:
        encoder = SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
    except EnvironmentError as err:
        # Re-raise with the message only; `from None` deliberately hides the loader's
        # internal exception chain from the user.
        raise EnvironmentError(str(err)) from None
    except Exception as err:
        # Normalize every other loading failure to RuntimeError for callers.
        raise RuntimeError(str(err)) from None

    return encoder