import ast
import json
import re
import warnings
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast

import lightgbm as lgb
import numpy as np
import pandas as pd
import shap
import xgboost as xgb
from numpy import ndarray
from scipy.sparse import csr_matrix, hstack
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer

# Default random seed used by the pipelines for reproducible training.
SEED = 42
# Number of decimal places the preprocessed feature matrix is rounded to.
# NOTE: the name is a misspelling of "ROUNDING_DIGITS"; it is left unchanged because other code refers to it.
ROUNDIND_DIGITS = 6


# Silence one specific UserWarning (matched by the message text below) that concerns the
# output format of SHAP's TreeExplainer for LightGBM binary classifiers.
warnings.filterwarnings(
    "ignore",
    message="LightGBM binary classifier with TreeExplainer shap values output has changed to a list of ndarray",
    category=UserWarning,
)


class Objective(Enum):
    """
    Enumeration of possible learning objectives for the model.

    Each member's value is a library-independent string identifier; the
    library-specific objective names are looked up through
    ``XGBOOST_OBJECTIVE_MAPPING`` and ``LIGHTGBM_OBJECTIVE_MAPPING``.

    Attributes
    ----------
    REG_SQUARED_ERROR
        Regression using squared error as the loss function (value ``"reg:squarederror"``).
    REG_ABSOLUTE_ERROR
        Regression using absolute error (L1 loss) as the loss function (value ``"reg:absoluteerror"``).
    REG_POISSON
        Regression for count data using the Poisson distribution (value ``"reg:poisson"``).
    BINARY
        Binary classification tasks (value ``"binary"``).
    MULTICLASS
        Multi-class classification tasks (value ``"multiclass"``).
    """

    REG_SQUARED_ERROR = "reg:squarederror"
    REG_ABSOLUTE_ERROR = "reg:absoluteerror"
    REG_POISSON = "reg:poisson"
    BINARY = "binary"
    MULTICLASS = "multiclass"


# Translation of each Objective into the objective name understood by XGBoost.
XGBOOST_OBJECTIVE_MAPPING = {
    Objective.REG_SQUARED_ERROR: "reg:squarederror",  # Mean Squared Error (MSE)
    Objective.REG_ABSOLUTE_ERROR: "reg:absoluteerror",  # Mean Absolute Error (MAE)
    Objective.REG_POISSON: "count:poisson",  # Poisson regression
    Objective.BINARY: "binary:logistic",  # Binary classification (log loss)
    Objective.MULTICLASS: "multi:softprob",  # Multiclass classification with probabilities
}


# Translation of each Objective into the objective name understood by LightGBM.
LIGHTGBM_OBJECTIVE_MAPPING = {
    Objective.REG_SQUARED_ERROR: "regression",  # Mean Squared Error (MSE)
    Objective.REG_ABSOLUTE_ERROR: "regression_l1",  # Mean Absolute Error (MAE)
    Objective.REG_POISSON: "poisson",  # Poisson regression
    Objective.BINARY: "binary",  # Binary classification (log loss)
    # "multiclass" is LightGBM's softmax objective; the previous value "multimulticlass"
    # is not a valid LightGBM objective name.
    Objective.MULTICLASS: "multiclass",  # Multiclass classification with probabilities
}


class Metric(Enum):
    """
    Enumeration of the evaluation metrics supported by the pipelines.

    Each member's value is a library-independent string identifier; the
    library-specific metric names are looked up through
    ``LIGHTGBM_METRICS_MAPPING`` and ``XGBOOST_METRICS_MAPPING``.
    """

    RMSE = "root_mean_squared_error"
    MAE = "mean_absolute_error"
    POISSON = "poisson_nloglik"
    LOGLOSS = "logloss"
    MLOGLOSS = "multi_logloss"
    AUC = "auc"
    MAUC = "multi_auc"


# Metric used when the caller does not choose one explicitly, keyed by objective.
DEFAULT_EVAL_METRIC = {
    Objective.REG_SQUARED_ERROR: Metric.RMSE,
    Objective.REG_ABSOLUTE_ERROR: Metric.MAE,
    Objective.REG_POISSON: Metric.POISSON,
    Objective.BINARY: Metric.LOGLOSS,
    Objective.MULTICLASS: Metric.MLOGLOSS,
}


# Translation of each Metric into the metric name understood by LightGBM.
LIGHTGBM_METRICS_MAPPING = {
    Metric.RMSE: "rmse",
    Metric.MAE: "l1",
    Metric.POISSON: "poisson",
    Metric.LOGLOSS: "binary_logloss",
    Metric.MLOGLOSS: "multi_logloss",
    Metric.AUC: "auc",
    # LightGBM's multi-class AUC metric is "auc_mu"; plain "auc" is a binary metric
    # and is rejected when combined with a multiclass objective.
    Metric.MAUC: "auc_mu",
}


# Translation of each Metric into the metric name understood by XGBoost.
XGBOOST_METRICS_MAPPING = {
    Metric.RMSE: "rmse",
    Metric.MAE: "mae",
    Metric.POISSON: "poisson-nloglik",
    Metric.LOGLOSS: "logloss",
    Metric.MLOGLOSS: "mlogloss",
    Metric.AUC: "auc",
    # XGBoost has no "auc_mu" metric; its "auc" metric also covers multi-class
    # problems when the objective is multi:softprob.
    Metric.MAUC: "auc",
}


def to_dict_series(series: pd.Series) -> pd.Series:
    """
    Convert JSON strings in a Series into dictionaries.

    Necessary when the feature table is read from a parquet file, where
    dictionary features are stored as JSON strings.

    Parameters
    ----------
    series: pd.Series
        Series whose elements may be JSON strings, already-parsed objects or missing values.

    Returns
    -------
    pd.Series
        A Series where every non-empty string (or bytes) element has been parsed with
        ``json.loads``. All other elements (dicts, None, NaN, empty strings, ...) are
        returned unchanged.

    Raises
    ------
    json.JSONDecodeError
        If a non-empty string element is not valid JSON.
    """

    def _parse(value):
        # Decide per element: a single NaN or pre-parsed dict must not prevent
        # the remaining string elements from being converted.
        if isinstance(value, (str, bytes, bytearray)) and value:
            return json.loads(value)
        return value

    return series.apply(_parse)


def to_list_series(series: pd.Series) -> pd.Series:
    """
    Turn stringified Python literals (typically lists) in a Series into real objects.

    String elements are evaluated with ``ast.literal_eval``; every other element
    is passed through untouched. If the conversion raises a ``TypeError`` the
    input Series is returned as-is.
    """

    def _evaluate(item):
        if isinstance(item, str):
            return ast.literal_eval(item)
        return item

    try:
        converted = series.apply(_evaluate)
    except TypeError:
        return series
    return converted


class CategoricalColumnTransformer:
    """
    Frequency-rank encoder for a single categorical column.

    After fitting, each category is mapped to an integer code:

    * ``-2`` for missing values seen during fitting,
    * ``-1`` for categories whose count is below ``threshold`` (and, at transform
      time, for any value that was not seen during fitting),
    * ``0, 1, 2, ...`` for the remaining categories, in the order returned by
      ``Series.value_counts`` (most frequent first).
    """

    def __init__(self, threshold: int = 5) -> None:
        """
        Initialize the transformer.

        Parameters
        ----------
        threshold: int=5
            The minimum count a category should have. Categories with counts below this will be assigned -1.
        """
        self.threshold = threshold
        # Keys are raw category values (any hashable, including NaN), values are integer codes.
        self.value_mapping: Dict[Any, int] = {}
        self.fitted = False

    def fit(self, series: pd.Series) -> None:
        """
        Fit the transformer to the data.

        Any mapping learned by a previous call to ``fit`` is discarded, so the
        transformer only reflects the most recent training data.

        Parameters
        ----------
        series: pd.Series
            The pandas Series to fit.
        """
        # Counts per category, most frequent first; missing values are counted too.
        value_counts = series.value_counts(dropna=False)

        # Build into a fresh dict so that refitting does not keep stale categories.
        mapping: Dict[Any, int] = {}
        rank = 0
        for category, count in zip(value_counts.index.tolist(), value_counts.tolist()):
            if pd.isna(category):  # Handle missing values
                mapping[category] = -2
            elif count < self.threshold:
                mapping[category] = -1
            else:
                # Unique rank per frequent category, even when counts are tied.
                mapping[category] = rank
                rank += 1

        self.value_mapping = mapping
        self.fitted = True

    def transform(self, series: pd.Series) -> pd.Series:
        """
        Transform the data based on the fitted transformer.

        Parameters
        ----------
        series: pd.Series
            The pandas Series to transform.

        Returns
        -------
        A pandas Series with transformed values. Values that have no entry in the
        fitted mapping are encoded as -1.

        Raises
        ------
        RuntimeError
            If the transformer has not been fitted yet.
        """
        if not self.fitted:
            raise RuntimeError("The transformer has not been fitted yet.")

        # Unmapped values become NaN after ``map`` and are then folded into the -1 bucket.
        return series.map(self.value_mapping).fillna(-1)

    def fit_transform(self, series: pd.Series) -> pd.Series:
        """
        Fit the transformer and transform the data in one step.

        Parameters
        ----------
        series: pd.Series
            The pandas Series to transform.

        Returns
        -------
        A pandas Series with transformed values.
        """
        self.fit(series)
        return self.transform(series)


class BasePipeline:
    def __init__(
        self,
        objective: Objective,
        eval_metric: Optional[Metric] = None,
        small_count_threshold: int = 5,
        seed: int = SEED,
    ) -> None:
        """
        Set up the shared pipeline configuration and empty preprocessing state.

        Parameters
        ----------
        objective: Objective
            The learning objective of the model. One of:
                Objective.REG_SQUARED_ERROR
                    Regression using squared error as the loss function.
                Objective.REG_ABSOLUTE_ERROR
                    Regression using absolute error (L1 loss) as the loss function.
                Objective.REG_POISSON
                    Regression for count data using the Poisson distribution.
                Objective.BINARY
                    Binary classification tasks.
                Objective.MULTICLASS
                    Multi-class classification tasks.
        eval_metric: Optional[Metric], optional
            The evaluation metric for the model. When omitted, the default for the
            objective is taken from ``DEFAULT_EVAL_METRIC``.
        small_count_threshold: int=5
            Threshold to define small category counts.
        seed: int, optional
            The seed for random number generation for reproducibility (defaults to ``SEED``).
        """
        # --- Configuration supplied by the caller ---
        self.objective = objective
        self.eval_metric = DEFAULT_EVAL_METRIC[self.objective] if eval_metric is None else eval_metric
        self.seed = seed
        self.small_count_threshold = small_count_threshold

        # --- State populated by ``train`` ---
        self.model: Optional[Union[xgb.Booster, lgb.Booster]] = None
        self.use_sparse = True
        self.numeric_features: List[str] = []
        self.categorical_features: List[str] = []
        self.dictionary_features: List[str] = []
        self.embedding_features: List[str] = []
        self.text_features: List[str] = []

        # --- State populated while preprocessing with ``fit=True`` ---
        self.value_mapping: Dict[Any, int] = {}
        self.min_values: Dict[str, float] = {}
        self.categorical_transformers: Dict[str, CategoricalColumnTransformer] = {}
        self.dict_vectorizers: Dict[str, DictVectorizer] = {}
        self.tfidf_vectorizers: Dict[str, TfidfVectorizer] = {}

        # Original feature name -> names of the matrix columns derived from it
        self.feature_mapping: Dict[str, List[str]] = {}

    def preprocess_numeric(self, df: pd.DataFrame, fit: bool = True) -> ndarray[np.float64, Any]:
        """
        Preprocess numeric features by imputing missing and infinite values.

        Both missing values and +/-inf are replaced by ``min - 1``, where ``min`` is
        the smallest finite value of the column observed while fitting. Infinite
        values are excluded from the minimum so the imputation value itself can
        never become ``-inf``.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing numeric features to preprocess.
        fit: bool, optional
            If True, fit the preprocessing by calculating the minimum values for imputation.
            If False, apply the pre-fitted minimum values for imputation.

        Returns
        -------
        A matrix
            A matrix representation of the numeric features with imputed values,
            one column per entry of ``self.numeric_features``.
        """
        assert self.numeric_features
        matrices = []

        for col in self.numeric_features:
            values = df[col]

            # Turn +/-inf into missing values first, so they neither distort the
            # fitted minimum nor survive into the output matrix.
            inf_mask = values.isin([np.inf, -np.inf])
            if inf_mask.any():
                values = values.mask(inf_mask)

            if fit:
                self.min_values[col] = values.min()

            # Impute with a value strictly below everything seen during fitting.
            transformed = values.fillna(self.min_values[col] - 1)
            matrices.append(transformed.values.reshape(-1, 1))

        return np.hstack(matrices)

    def preprocess_categorical(
        self, df: pd.DataFrame, fit: bool = True
    ) -> ndarray[np.float64, Any]:
        """
        Encode every categorical feature with a ``CategoricalColumnTransformer``.

        Frequent categories receive a rank-based code, rare categories share one
        special code and missing values get their own code.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing categorical features to preprocess.
        fit: bool, optional
            If True, a new transformer is fitted per column and stored in
            ``self.categorical_transformers``.
            If False, the previously stored transformers are applied.

        Returns
        -------
        A matrix
            The encoded features, one column per entry of ``self.categorical_features``.
        """
        encoded_columns = []
        for col in self.categorical_features:
            if not fit:
                encoded = self.categorical_transformers[col].transform(df[col])
            else:
                encoder = CategoricalColumnTransformer(threshold=self.small_count_threshold)
                encoded = encoder.fit_transform(df[col])
                # Remember the fitted encoder for later transform-only calls
                self.categorical_transformers[col] = encoder

            encoded_columns.append(encoded.values.reshape(-1, 1))

        return np.hstack(encoded_columns)

    def preprocess_dictionary(
        self, df: pd.DataFrame, fit: bool = True
    ) -> Tuple[csr_matrix, List[str]]:
        """
        Vectorize dictionary features with one ``DictVectorizer`` per column.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing dictionary features to preprocess.
        fit: bool, optional
            If True, fit a new vectorizer per column and record the generated
            column names in ``self.feature_mapping``.
            If False, apply the previously fitted vectorizers.

        Returns
        -------
        A sparse matrix
            The horizontally stacked vectorized dictionary features
            (an empty sparse matrix when there are no dictionary features).
        column_names
            List of column names based on the dict vectorizers.
        """
        blocks: List[csr_matrix] = []
        names: List[str] = []

        for col in self.dictionary_features:
            # Parse stringified dicts and substitute an empty dict for missing entries
            dict_col = to_dict_series(df[col]).apply(lambda x: {} if pd.isna(x) else x)

            if not fit:
                block = self.dict_vectorizers[col].transform(dict_col)
            else:
                vectorizer = DictVectorizer(sparse=True)
                block = vectorizer.fit_transform(dict_col)
                # Prefix with the source column and sanitize characters outside [a-zA-Z0-9_]
                sanitized = [
                    f"{col}_{re.sub(r'[^a-zA-Z0-9_]', '_', name)}"
                    for name in vectorizer.get_feature_names_out()
                ]
                self.dict_vectorizers[col] = vectorizer
                self.feature_mapping[col] = sanitized

            blocks.append(block)
            names.extend(self.feature_mapping[col])

        if blocks:
            stacked = hstack(blocks)
        else:
            stacked = csr_matrix([])
        return stacked, names

    def preprocess_embedding(
        self, df: pd.DataFrame, fit: bool = True
    ) -> Tuple[ndarray[np.float64, Any], List[str]]:
        """
        Preprocess embedding features by flattening each embedding vector into multiple features.

        Missing embeddings (None or NaN) are imputed with a vector of zeros.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing embedding features to preprocess.
        fit: bool, optional
            If True, derive the embedding length from the first non-missing value of
            each column and record the generated column names in ``self.feature_mapping``.
            If False, reuse the column names recorded during fitting.

        Returns
        -------
        A matrix
            A matrix representation of the flattened embedding vectors. When no
            column yields any embedding feature, an empty matrix with ``len(df)``
            rows and zero columns is returned so it can still be stacked safely.
        column_names
            List of column names.
        """
        embedding_matrices = []
        column_names: List[str] = []

        for col in self.embedding_features:
            embedding_col = to_list_series(df[col])
            missing_flag = embedding_col.isnull()

            if fit:
                # Use the same missing-value definition (None and NaN) for length
                # detection as for imputation, so a leading NaN cannot break ``len``.
                present = embedding_col[~missing_flag]
                embedding_length = len(present.iloc[0]) if len(present) else 0
                self.feature_mapping[col] = [f"{col}_{i}" for i in range(embedding_length)]

            feature_names = self.feature_mapping[col]
            if not feature_names:
                # Nothing was learned for this column (e.g. all values missing during fit)
                continue

            if missing_flag.any():
                # Impute null values with the list of zeros
                zeros_list = [0] * len(feature_names)
                rows = [
                    zeros_list if is_missing else vector
                    for vector, is_missing in zip(embedding_col, missing_flag)
                ]
            else:
                rows = embedding_col.tolist()

            embeddings = pd.DataFrame(rows, index=df.index)
            embedding_matrices.append(embeddings.values)
            column_names.extend(feature_names)

        if embedding_matrices:
            return np.hstack(embedding_matrices), column_names
        return np.empty((len(df), 0)), column_names

    def preprocess_text(self, df: pd.DataFrame, fit: bool = True) -> Tuple[csr_matrix, List[str]]:
        """
        Turn text features into TF-IDF features, one ``TfidfVectorizer`` per column.

        Missing text is treated as an empty string.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing text features to preprocess.
        fit: bool, optional
            If True, fit a new TF-IDF vectorizer per column and record the generated
            column names in ``self.feature_mapping``.
            If False, apply the previously fitted vectorizers.

        Returns
        -------
        A sparse matrix
            The horizontally stacked TF-IDF features
            (an empty sparse matrix when there are no text features).
        column_names
            List of column names based on the TfidfVectorizer.
        """
        blocks = []
        names: List[str] = []

        for col in self.text_features:
            documents = df[col].fillna("").astype(str)

            if not fit:
                # Re-use the vectorizer fitted on the training data
                block = self.tfidf_vectorizers[col].transform(documents)
            else:
                vectorizer = TfidfVectorizer(
                    ngram_range=(1, 2),  # Unigrams and bigrams
                    max_features=20000,  # Cap on the vocabulary size
                    norm="l2",  # L2-normalize each document vector
                    stop_words="english",  # Drop common English stop words
                )
                block = vectorizer.fit_transform(documents)
                self.tfidf_vectorizers[col] = vectorizer
                # Prefix with the source column and sanitize characters outside [a-zA-Z0-9_]
                self.feature_mapping[col] = [
                    f"{col}_{re.sub(r'[^a-zA-Z0-9_]', '_', name)}"
                    for name in vectorizer.get_feature_names_out()
                ]

            blocks.append(block)
            names.extend(self.feature_mapping[col])

        if blocks:
            stacked = hstack(blocks)
        else:
            stacked = csr_matrix([])
        return stacked, names

    def preprocess(
        self, df: pd.DataFrame, fit: bool = True
    ) -> Tuple[Union[csr_matrix, ndarray[np.float64, Any]], List[str]]:
        """
        Preprocess all specified features (numeric, categorical, dictionary, embedding, and text)
        and combine them into a single matrix.

        Feature groups are stacked in the fixed order numeric, categorical, dictionary,
        embedding, text. Groups that produce no columns are skipped. All values are
        rounded to ``ROUNDIND_DIGITS`` decimal places.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing features of different types to preprocess.
        fit: bool, optional
            If True, fit the preprocessors for all column types.
            If False, apply the pre-fitted preprocessors to the data.

        Returns
        -------
        A matrix (sparse or dense)
            A combined matrix representation of all preprocessed features: a CSR
            matrix when ``self.use_sparse`` is True, a dense array otherwise.
        column_names
            List of column names for the matrix.

        Raises
        ------
        ValueError
            If no feature group produced any column.
        """
        matrices = []
        column_names: List[str] = []

        # Numeric features (one column each)
        if self.numeric_features:
            numeric_matrix = self.preprocess_numeric(df, fit=fit)
            if self.use_sparse:
                numeric_matrix = csr_matrix(numeric_matrix)
            matrices.append(numeric_matrix)
            column_names.extend(self.numeric_features)

            if fit:
                # Update feature mapping
                for feature in self.numeric_features:
                    self.feature_mapping[feature] = [feature]

        # Categorical features (one column each)
        if self.categorical_features:
            categorical_matrix = self.preprocess_categorical(df, fit=fit)
            if self.use_sparse:
                categorical_matrix = csr_matrix(categorical_matrix)
            matrices.append(categorical_matrix)
            column_names.extend(self.categorical_features)

            if fit:
                # Update feature mapping
                for feature in self.categorical_features:
                    self.feature_mapping[feature] = [feature]

        # Dictionary features (already sparse)
        if self.dictionary_features:
            dictionary_matrix, dict_column_names = self.preprocess_dictionary(df, fit=fit)
            if dictionary_matrix.shape[1]:
                matrices.append(dictionary_matrix)
                column_names.extend(dict_column_names)

        # Embedding features
        if self.embedding_features:
            embedding_matrix, embedding_column_names = self.preprocess_embedding(df, fit=fit)
            # An embedding block without columns (e.g. every value missing during fit)
            # has no usable shape for stacking, so it is skipped like the empty
            # dictionary/text blocks.
            if embedding_column_names:
                if self.use_sparse:
                    embedding_matrix = csr_matrix(embedding_matrix)
                matrices.append(embedding_matrix)
                column_names.extend(embedding_column_names)

        # Text features (already sparse)
        if self.text_features:
            text_matrix, text_column_names = self.preprocess_text(df, fit=fit)
            if text_matrix.shape[1]:
                matrices.append(text_matrix)
                column_names.extend(text_column_names)

        if not matrices:
            raise ValueError("No features available to build the feature matrix.")

        if self.use_sparse:
            # Concatenate all sparse matrices; request CSR explicitly because
            # scipy's hstack returns COO by default.
            full_matrix = hstack(matrices, format="csr")
            full_matrix.data = np.round(full_matrix.data, ROUNDIND_DIGITS)
        else:
            # Concatenate all dense matrices
            full_matrix = np.hstack(matrices)
            full_matrix = np.round(full_matrix, ROUNDIND_DIGITS)

        assert full_matrix.shape[1] == len(column_names)
        return full_matrix, column_names

    def train(
        self,
        df_train: pd.DataFrame,
        df_test: Optional[pd.DataFrame],
        y_train: pd.Series,
        y_test: Optional[pd.Series],
        feature_types: Dict[
            str, Literal["numeric", "categorical", "dictionary", "embedding", "text", "others"]
        ],
    ) -> None:
        """
        Sort the training columns into feature groups according to ``feature_types``.

        This base implementation fits no model; it only records which columns are
        numeric, categorical, dictionary, embedding or text features and decides
        whether the feature matrix will be sparse (only when dictionary or text
        features exist). Columns of any other type are ignored.

        Parameters
        ----------
        df_train: pd.DataFrame
            The input DataFrame containing features to be used for model training.
        df_test: pd.DataFrame
            The input DataFrame containing features to be used for model test (unused here).
        y_train: pd.Series
            The target variable for model training (unused here).
        y_test: pd.Series
            The target variable for model test (unused here).
        feature_types: Dict[str, Literal["numeric", "categorical", "dictionary", "embedding", "text", "others"]]
            The types of the features. Every column of ``df_train`` must have an entry.

        Returns
        -------
        None
        """
        # Accepted for interface compatibility with subclasses; not needed here.
        _ = (df_test, y_train, y_test)

        # Resolve the declared type of every column once, preserving column order.
        declared = [(name, feature_types[name]) for name in df_train.columns]

        def _columns_of(kind: str) -> List[str]:
            return [name for name, declared_kind in declared if declared_kind == kind]

        self.numeric_features = _columns_of("numeric")
        self.categorical_features = _columns_of("categorical")
        self.dictionary_features = _columns_of("dictionary")
        self.embedding_features = _columns_of("embedding")
        self.text_features = _columns_of("text")

        has_sparse_source = len(self.dictionary_features + self.text_features) > 0
        self.use_sparse = has_sparse_source

    def compute_shap(
        self, df: pd.DataFrame, chunk_size: Optional[int] = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Compute SHAP values aggregated per original feature, processing the input
        in chunks if needed.

        SHAP values are computed on the preprocessed matrix and then summed over
        all matrix columns that derive from the same original column of ``df``
        (as recorded in ``self.feature_mapping``). Columns of ``df`` without an
        entry in the mapping get a contribution of zero.

        Parameters
        ----------
        df: pd.DataFrame
            The rows to explain.
        chunk_size: Optional[int], optional
            Number of rows processed at once. When omitted, the data is processed in
            one go unless the SHAP matrix is estimated to exceed roughly 1e9 bytes,
            in which case a chunk size keeping each chunk under that size is derived.

        Returns
        -------
        shap_summary: pd.DataFrame
            Aggregated SHAP values with the same index and columns as ``df``.
        feature_importance: pd.DataFrame
            One row per feature with the mean absolute SHAP value (``importance``),
            sorted in descending order, plus cumulative importance columns.

        Raises
        ------
        ValueError
            If the pipeline has not been trained, or for multi-class models.
        """
        if self.model is None:
            raise ValueError("Pipeline has not been trained yet.")
        if self.objective == Objective.MULTICLASS:
            raise ValueError("Multi-class not supported yet.")

        explainer = shap.TreeExplainer(self.model)

        # Preprocess a single row just to obtain the column names of the matrix
        _, column_names = self.preprocess(df.head(1), fit=False)

        # Approximate memory usage of shap matrix (8 bytes per float64 value)
        if chunk_size is None:
            approx_ram = df.shape[0] * len(column_names) * 8
            # Determine whether chunking is necessary
            if approx_ram > 1e9:
                chunk_size = int(1e9 / approx_ram * df.shape[0])

        # Name -> position lookup built once, instead of a linear ``list.index``
        # search per derived column. ``setdefault`` keeps the first position for
        # duplicated names, matching ``list.index``.
        column_positions: Dict[str, int] = {}
        for position, name in enumerate(column_names):
            column_positions.setdefault(name, position)

        # mapping_matrix[i, j] == 1 when matrix column j derives from df column i
        mapping_matrix = np.zeros((len(df.columns), len(column_names)))
        for i, column in enumerate(df.columns):
            related_columns = self.feature_mapping.get(column)
            if not related_columns:
                continue
            try:
                indices = [column_positions[name] for name in related_columns]
            except KeyError as err:
                raise ValueError(f"{err.args[0]!r} is not in list") from err
            mapping_matrix[i, indices] = 1

        def _aggregated_shap(frame: pd.DataFrame) -> Any:
            """Preprocess ``frame``, compute SHAP values and sum them per original feature."""
            processed, _ = self.preprocess(frame, fit=False)
            values = explainer.shap_values(processed)
            if isinstance(values, list):
                if len(values) != 2:
                    raise ValueError("Multi-class not supported yet.")
                # LightGBM binary models: use the last (positive) class
                values = values[-1]
            elif getattr(values, "ndim", 2) == 3:
                # NOTE(review): presumably newer SHAP releases return per-class output
                # as one (rows, features, classes) array instead of a list — confirm
                # against the installed SHAP version.
                if values.shape[-1] != 2:
                    raise ValueError("Multi-class not supported yet.")
                values = values[..., -1]
            # Aggregate SHAP values using matrix multiplication
            return values @ mapping_matrix.T

        if chunk_size:
            # Process data in chunks to bound peak memory; slicing past the end is safe
            shap_summary_matrix = np.vstack([
                _aggregated_shap(df.iloc[start : start + chunk_size])
                for start in range(0, len(df), chunk_size)
            ])
        else:
            # Process the entire DataFrame in one go
            shap_summary_matrix = _aggregated_shap(df)

        # Ensure shap_summary.columns matches df.columns
        shap_summary = pd.DataFrame(shap_summary_matrix, columns=df.columns, index=df.index)

        importance_scores = shap_summary.abs().mean(axis=0)

        # Create a DataFrame for feature importance
        feature_importance = pd.DataFrame({
            "feature": importance_scores.index,
            "importance": importance_scores.values,
        })

        # Sort features by importance
        feature_importance = feature_importance.sort_values(
            by="importance", ascending=False
        ).reset_index(drop=True)

        feature_importance["cumulative_importance"] = feature_importance["importance"].cumsum()
        total_importance = feature_importance["importance"].sum()
        feature_importance["cumulative_importance_percent"] = (
            feature_importance["cumulative_importance"] / total_importance
        )

        return shap_summary, feature_importance

    @abstractmethod
    def get_best_score(self) -> float:
        """
        Return the best score of the trained model.

        Must be implemented by subclasses; this base implementation has an empty
        body and therefore returns None.

        NOTE(review): ``BasePipeline`` is declared without ``abc.ABC`` as a base
        class, so ``@abstractmethod`` is presumably not enforced when a subclass
        is instantiated — confirm whether that is intended.
        """
        pass


class XGBoostPipeline(BasePipeline):
    def __init__(
        self,
        objective: Objective,
        eval_metric: Optional[Metric] = None,
        num_boost_round: int = 20000,
        early_stopping_rounds: int = 50,
        learning_rate: float = 0.005,
        max_depth: int = 6,
        subsample: float = 0.8,
        colsample_bytree: float = 0.5,
        gamma: float = 0.025,
        small_count_threshold: int = 5,
        seed: int = SEED,
    ) -> None:
        """
        Configure an XGBoost pipeline.

        Parameters
        ----------
        objective : Objective
            The learning task and corresponding loss function:
                - **Objective.REG_SQUARED_ERROR**: Regression using squared error as the loss function.
                - **Objective.REG_ABSOLUTE_ERROR**: Regression using absolute error (L1 loss) as the loss function.
                - **Objective.REG_POISSON**: Regression for count data using the Poisson distribution.
                - **Objective.BINARY**: Binary classification tasks.
                - **Objective.MULTICLASS**: Multi-class classification tasks.
        eval_metric : Optional[Metric], optional
            The evaluation metric used to assess performance. Selected from the objective if not provided:
                - **Metric.RMSE**: Root mean squared error.
                - **Metric.MAE**: Mean absolute error.
                - **Metric.POISSON**: Negative log-likelihood for Poisson regression.
                - **Metric.LOGLOSS**: Logarithmic loss for binary classification.
                - **Metric.MLOGLOSS**: Logarithmic loss for multi-class classification.
                - **Metric.AUC**: Area under the ROC curve.
                - **Metric.MAUC**: Area under the ROC curve for multi-class classification.
        num_boost_round : int, optional
            The maximum number of boosting rounds. Default is 20000.
        early_stopping_rounds : int, optional
            Number of rounds without improvement to trigger early stopping. Default is 50.
        learning_rate : float, optional
            The step size shrinkage parameter. Smaller values slow down learning but can improve performance. Default is 0.005.
        max_depth : int, optional
            Maximum depth of the decision trees. Default is 6.
        subsample : float, optional
            Fraction of training samples used for fitting each tree. Default is 0.8.
        colsample_bytree : float, optional
            Fraction of features considered when constructing each tree. Default is 0.5.
        gamma : float, optional
            Minimum loss reduction required to make a split. Default is 0.025.
        small_count_threshold : int, optional
            Threshold for defining small category counts in categorical features. Default is 5.
        seed : int, optional
            Random seed for reproducibility. Defaults to ``SEED``.
        """
        # Shared configuration lives on the base class
        super().__init__(
            objective=objective,
            eval_metric=eval_metric,
            small_count_threshold=small_count_threshold,
            seed=seed,
        )

        # Boosting schedule
        self.learning_rate = learning_rate
        self.num_boost_round = num_boost_round
        self.early_stopping_rounds = early_stopping_rounds

        # Tree shape and split regularization
        self.max_depth = max_depth
        self.gamma = gamma

        # Row / column subsampling
        self.subsample = subsample
        self.colsample_bytree = colsample_bytree

    def train(
        self,
        df_train: pd.DataFrame,
        df_test: Optional[pd.DataFrame],
        y_train: pd.Series,
        y_test: Optional[pd.Series],
        feature_types: Dict[
            str, Literal["numeric", "categorical", "dictionary", "embedding", "text", "others"]
        ],
    ) -> None:
        """
        Fit the XGBoost model by preprocessing the input data.
        Do early stopping if df_test and y_test are both present; otherwise train
        for the full ``num_boost_round`` rounds without an evaluation set.

        Parameters
        ----------
        df_train: pd.DataFrame
            The input DataFrame containing features to be used for model training.
        df_test: pd.DataFrame
            The input DataFrame containing features to be used for model test.
        y_train: pd.Series
            The target variable for model training.
        y_test: pd.Series
            The target variable for model test.
        feature_types: Dict[str, Literal["numeric", "categorical", "dictionary", "embedding", "text", "others"]]
            The types of the features.

        Returns
        -------
        None
        """

        # Let the base class sort the columns into feature groups
        super().train(
            df_train=df_train,
            df_test=df_test,
            y_train=y_train,
            y_test=y_test,
            feature_types=feature_types,
        )

        params: Dict[str, Any] = {
            "objective": XGBOOST_OBJECTIVE_MAPPING[self.objective],
            "learning_rate": self.learning_rate,
            "max_depth": self.max_depth,
            "subsample": self.subsample,
            "colsample_bytree": self.colsample_bytree,
            "gamma": self.gamma,
            "seed": self.seed,  # Fixed seed for XGBoost
            "verbosity": 1,
        }

        # Add num_class for multi-class classification
        if self.objective == Objective.MULTICLASS:
            params["num_class"] = y_train.nunique()

        X_train_processed, column_names = self.preprocess(df_train, fit=True)
        dtrain = xgb.DMatrix(X_train_processed, label=y_train, feature_names=column_names)

        train_kwargs: Dict[str, Any] = {"num_boost_round": self.num_boost_round}

        # An evaluation set is only usable when both features and labels are given;
        # a test frame without labels cannot be scored.
        if df_test is not None and y_test is not None:
            params["eval_metric"] = XGBOOST_METRICS_MAPPING[self.eval_metric]
            X_test_processed, _ = self.preprocess(df_test, fit=False)
            dtest = xgb.DMatrix(X_test_processed, label=y_test, feature_names=column_names)
            train_kwargs["evals"] = [(dtest, "eval")]
            train_kwargs["early_stopping_rounds"] = self.early_stopping_rounds
            train_kwargs["verbose_eval"] = False

        self.model = xgb.train(params, dtrain, **train_kwargs)

    def predict(self, df: pd.DataFrame) -> ndarray[np.float64, Any]:
        """
        Predict with the trained XGBoost model.

        The input is run through the already-fitted preprocessors before being
        handed to the booster.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing features to be used for prediction.

        Returns
        -------
        ndarray
            An array of predicted values.

        Raises
        ------
        ValueError
            If the pipeline has not been trained yet.
        """
        booster = self.model
        if booster is None:
            raise ValueError("Pipeline has not been trained yet.")

        features, feature_names = self.preprocess(df, fit=False)
        predictions = booster.predict(xgb.DMatrix(features, feature_names=feature_names))
        return cast(ndarray[np.float64, Any], predictions)

    def get_best_score(self) -> float:
        """
        Return the ``best_score`` attribute of the trained booster.

        Raises
        ------
        ValueError
            If the pipeline has not been trained yet.
        """
        booster = self.model
        if booster is None:
            raise ValueError("Pipeline has not been trained yet.")
        best = booster.best_score
        return cast(float, best)


class LightGBMPipeline(BasePipeline):
    """
    Pipeline that preprocesses tabular data and trains a LightGBM booster.

    The constructor stores the hyperparameters, `train` fits the preprocessing
    and the booster, `predict` applies both to new data, and `get_best_score`
    reads the validation score recorded during training.
    """

    def __init__(
        self,
        objective: Objective,
        eval_metric: Optional[Metric] = None,  # Auto-determined based on task if not provided
        num_boost_round: int = 20000,
        early_stopping_rounds: int = 50,
        learning_rate: float = 0.005,
        max_depth: int = 10,
        num_leaves: int = 54,
        subsample: float = 0.8,
        colsample_bytree: float = 0.5,
        min_split_gain: float = 0.025,
        reg_alpha: float = 0.5,
        reg_lambda: float = 0.5,
        small_count_threshold: int = 5,
        seed: int = SEED,  # Seed for reproducibility
    ) -> None:
        """
        Initialize the model.

        Parameters
        ----------
        objective : Objective
            The model's objective function, defining the learning task and corresponding loss function. Options include:
                - **reg:squarederror**: Regression using squared error as the loss function.
                - **reg:absoluteerror**: Regression using absolute error (L1 loss) as the loss function.
                - **reg:poisson**: Regression for count data using the Poisson distribution.
                - **binary**: Binary classification tasks.
                - **multiclass**: Multi-class classification tasks.
        eval_metric : Optional[Metric], optional
            The evaluation metric for the model, used to assess performance. Automatically selected based on the objective if not provided. Common options include:
                - **root_mean_squared_error**: Root mean squared error.
                - **mean_absolute_error**: Mean absolute error.
                - **poisson_nloglik**: Negative log-likelihood for Poisson regression.
                - **logloss**: Logarithmic loss for binary classification.
                - **multi_logloss**: Logarithmic loss for multi-class classification.
                - **auc**: Area under the ROC curve.
                - **multi_auc**: Area under the ROC curve for multi-class classification.
        num_boost_round : int, optional
            The maximum number of boosting rounds. Default is 20000.
        early_stopping_rounds : int, optional
            Number of rounds without improvement to trigger early stopping. Only used when a test set is passed to `train`. Default is 50.
        learning_rate : float, optional
            The step size shrinkage parameter for gradient descent. Smaller values slow down learning but can improve performance. Default is 0.005.
        max_depth : int, optional
            Maximum depth of the decision trees. Controls model complexity and risk of overfitting. Use `-1` for no limit. Default is 10.
        num_leaves : int, optional
            Maximum number of leaves for each tree. A higher value increases model capacity and complexity. Default is 54.
        subsample : float, optional
            Fraction of training samples to use for fitting each tree. Helps prevent overfitting. Default is 0.8.
        colsample_bytree : float, optional
            Fraction of features to consider when constructing each tree. Helps reduce overfitting. Default is 0.5.
        min_split_gain : float, optional
            Minimum loss reduction required to make a split, similar to `gamma` in XGBoost. Helps control overfitting. Default is 0.025.
        reg_alpha : float, optional
            L1 regularization term on weights. Increases sparsity of the model. Default is 0.5.
        reg_lambda : float, optional
            L2 regularization term on weights. Helps prevent overfitting. Default is 0.5.
        small_count_threshold : int, optional
            Threshold for defining small category counts in categorical features. Categories with counts below this value may be grouped together or treated differently. Default is 5.
        seed : int, optional
            Random seed for reproducibility. Default is 42.
        """
        super().__init__(
            objective=objective,
            eval_metric=eval_metric,
            small_count_threshold=small_count_threshold,
            seed=seed,
        )
        # Boosting schedule
        self.num_boost_round = num_boost_round
        self.early_stopping_rounds = early_stopping_rounds
        self.learning_rate = learning_rate
        # Tree shape
        self.max_depth = max_depth
        self.num_leaves = num_leaves
        # Row / column sampling
        self.subsample = subsample
        self.colsample_bytree = colsample_bytree
        # Regularization
        self.min_split_gain = min_split_gain
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda

    def train(
        self,
        df_train: pd.DataFrame,
        df_test: Optional[pd.DataFrame],
        y_train: pd.Series,
        y_test: Optional[pd.Series],
        feature_types: Dict[
            str, Literal["numeric", "categorical", "dictionary", "embedding", "text", "others"]
        ],
    ) -> None:
        """
        Fit the LightGBM model by preprocessing the input data.
        Do early stopping if df_test and y_test are present.

        Parameters
        ----------
        df_train: pd.DataFrame
            The input DataFrame containing features to be used for model training.
        df_test: Optional[pd.DataFrame]
            The input DataFrame containing features to be used for model test.
            When None, the model is trained for `num_boost_round` rounds with
            no validation set and no early stopping.
        y_train: pd.Series
            The target variable for model training.
        y_test: Optional[pd.Series]
            The target variable for model test.
        feature_types: Dict[str, Literal["numeric", "categorical", "dictionary", "embedding", "text", "others"]]
            The types of the features.

        Returns
        -------
        None
        """

        super().train(
            df_train=df_train,
            df_test=df_test,
            y_train=y_train,
            y_test=y_test,
            feature_types=feature_types,
        )

        params: Dict[str, Any] = {
            "objective": LIGHTGBM_OBJECTIVE_MAPPING[self.objective],
            "learning_rate": self.learning_rate,
            "num_leaves": self.num_leaves,
            "max_depth": self.max_depth,
            "subsample": self.subsample,
            # LightGBM only applies `subsample` (bagging_fraction) when bagging
            # is enabled through a non-zero bagging frequency; with the default
            # of 0 the fraction above is silently ignored.
            "subsample_freq": 1,
            "colsample_bytree": self.colsample_bytree,
            # Previously stored by __init__ but never forwarded to LightGBM.
            "min_split_gain": self.min_split_gain,
            "reg_alpha": self.reg_alpha,
            "reg_lambda": self.reg_lambda,
            "seed": self.seed,
            "verbose": -1,  # Suppress all messages
        }

        # Add num_class for multi-class classification
        if self.objective == Objective.MULTICLASS:
            params["num_class"] = y_train.nunique()

        X_train_processed, column_names = self.preprocess(df_train, fit=True)
        train_data = lgb.Dataset(X_train_processed, label=y_train, feature_name=column_names)

        # A validation set (and therefore metric tracking and early stopping)
        # is only configured when test data is supplied.
        valid_sets: Optional[List[lgb.Dataset]] = None
        if df_test is not None:
            params["metric"] = LIGHTGBM_METRICS_MAPPING[self.eval_metric]
            params["early_stopping_round"] = self.early_stopping_rounds
            X_test_processed, _ = self.preprocess(df_test, fit=False)
            valid_sets = [
                lgb.Dataset(
                    X_test_processed,
                    label=y_test,
                    feature_name=column_names,
                    reference=train_data,
                )
            ]

        self.model = lgb.train(
            params,
            train_data,
            valid_sets=valid_sets,
            num_boost_round=self.num_boost_round,
        )

    def predict(self, df: pd.DataFrame) -> ndarray[np.float64, Any]:
        """
        Make predictions using the pre-trained LightGBM model on the input data.

        Parameters
        ----------
        df: pd.DataFrame
            The input DataFrame containing features to be used for prediction.

        Returns
        -------
        ndarray
            The values returned by the booster's ``predict`` call.

        Raises
        ------
        ValueError
            If the pipeline has not been trained yet.
        """
        if self.model is None:
            raise ValueError("Pipeline has not been trained yet.")
        # Narrows the model type for static type checkers.
        assert isinstance(self.model, lgb.Booster)

        X_processed, _ = self.preprocess(df, fit=False)
        return cast(ndarray[np.float64, Any], self.model.predict(X_processed))

    def get_best_score(self) -> float:
        """
        Get the best validation score of the model for the evaluation metric.

        Returns
        -------
        float
            The value stored under ``best_score["valid_0"]`` for the LightGBM
            name of `eval_metric`.

        Raises
        ------
        ValueError
            If the pipeline has not been trained yet.
        KeyError
            If no score for the metric was recorded for ``valid_0``.
        """
        if self.model is None:
            raise ValueError("Pipeline has not been trained yet.")
        lgbm_metric = LIGHTGBM_METRICS_MAPPING[self.eval_metric]
        best_score = cast(Dict[str, Dict[str, float]], self.model.best_score)
        # `.get` is used instead of indexing so that a missing "valid_0" entry
        # is never inserted as a side effect if `best_score` is a defaultdict.
        valid_scores = best_score.get("valid_0")
        if valid_scores is None or lgbm_metric not in valid_scores:
            raise KeyError(f"Metric '{lgbm_metric}' not found in best_score.")
        return valid_scores[lgbm_metric]
