import os
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator


# ================= 1. INPEFA / PEFA 算法（最小输出版） =================
class InpefaCalculator:
    """INPEFA / PEFA calculator (minimal output set).

    Two families of core curves are kept:
    1) New local sliding-window scheme: outputs PEFA / raw-INPEFA.
    2) Old global-scale scheme: outputs legacy-global-INPEFA.
    """

    @staticmethod
    def _safe_fill_nan(data: np.ndarray) -> np.ndarray:
        """Coerce *data* to float64 and replace NaNs with the series mean.

        An all-NaN input degrades to an all-zero array so the downstream AR
        fitting never sees missing values.
        """
        data = np.asarray(data, dtype=np.float64)
        if data.size == 0:
            return data
        mean_val = np.nanmean(data)
        if np.isnan(mean_val):
            # Every sample was NaN: fall back to a flat zero series.
            return np.zeros_like(data, dtype=np.float64)
        return np.nan_to_num(data, nan=mean_val)

    @staticmethod
    def _burg_ar(x: np.ndarray, order: int):
        """Estimate an AR prediction-error filter for *x* via Burg's method.

        Returns ``(a, rho)``: the filter taps ``a`` (with ``a[0] == 1``) and
        the remaining error power ``rho``. Inputs shorter than 3 samples get
        the trivial filter ``([1.0], 0.0)``.
        """
        x = np.asarray(x, dtype=np.float64)
        N = len(x)
        if N < 3:
            return np.array([1.0], dtype=np.float64), 0.0

        # Clamp the order so the shrinking error arrays stay non-empty.
        order = int(max(1, min(order, N - 2)))
        a = np.array([1.0], dtype=np.float64)
        f = x.copy()  # forward prediction errors
        b = x.copy()  # backward prediction errors
        rho = np.dot(x, x) / max(N, 1)  # initial error power

        for _ in range(order):
            if len(f) < 2 or len(b) < 2:
                break  # defensive; with the clamp above this should not trigger

            f1 = f[1:]
            b1 = b[:-1]

            # Reflection coefficient, clipped just inside (-1, 1) so rho
            # never collapses to exactly zero.
            den = np.dot(f1, f1) + np.dot(b1, b1)
            mu = 0.0 if den == 0 else (-2.0 * np.dot(f1, b1) / den)
            mu = np.clip(mu, -0.999999, 0.999999)

            # Levinson-style coefficient update: a_k += mu * a_{m-k};
            # the newly appended last tap is mu itself.
            current_a = a.copy()
            a = np.append(a, 0.0)
            if len(current_a) > 1:
                a[1:-1] = current_a[1:] + mu * current_a[1:][::-1]
            a[-1] = mu

            # Lattice recursion for the forward/backward error sequences.
            f_old = f1.copy()
            b_old = b1.copy()
            f = f_old + mu * b_old
            b = b_old + mu * f_old
            rho *= (1.0 - mu ** 2)

        return a, rho

    @staticmethod
    def _predict_next_from_history(history: np.ndarray, ar_order: int) -> float:
        """One-step-ahead AR prediction of the value following *history*.

        Falls back to the last observed value when the history is too short
        (or the fit degenerates), and to 0.0 for an empty history.
        """
        history = InpefaCalculator._safe_fill_nan(history)
        n = len(history)
        if n < 3:
            return float(history[-1]) if n > 0 else 0.0

        ar_order = int(max(1, min(ar_order, n - 2)))
        # Fit on the mean-removed series; the mean is added back at the end.
        mean_val = np.mean(history)
        centered = history - mean_val

        coeffs, _ = InpefaCalculator._burg_ar(centered, ar_order)
        if len(coeffs) <= 1:
            return float(history[-1])

        # AR prediction: x_hat[n] = -sum_k a_k * x[n-k], most recent first.
        past = centered[-ar_order:][::-1]
        pred_centered = -np.dot(coeffs[1:ar_order + 1], past)
        return float(pred_centered + mean_val)

    @staticmethod
    def _local_pefa_one_direction(data: np.ndarray, window_length: int, ar_order: int) -> np.ndarray:
        """Sliding-window prediction error (PEFA), scanning *data* forward.

        The first ``window_length`` samples have no full history and stay 0.
        """
        data = InpefaCalculator._safe_fill_nan(data)
        n = len(data)

        if n < 5:
            return np.zeros(n, dtype=np.float64)

        # Keep window and order consistent with the available data.
        window_length = int(max(3, min(window_length, n - 1)))
        ar_order = int(max(1, min(ar_order, window_length - 1)))

        pefa = np.zeros(n, dtype=np.float64)
        for i in range(window_length, n):
            # Predict sample i from the preceding window only.
            history = data[i - window_length:i]
            pred = InpefaCalculator._predict_next_from_history(history, ar_order)
            pefa[i] = data[i] - pred
        return pefa

    @staticmethod
    def calculate_pefa_raw(
        data: np.ndarray,
        window_length: int,
        ar_order: int,
        direction: str = "reverse",
        center_pefa_before_integrate: bool = True,
        standardize_pefa: bool = True,
    ):
        """Compute the local-window PEFA and its integral (raw-INPEFA).

        Parameters
        ----------
        data : log values (e.g. GR); NaNs are tolerated.
        window_length : sliding-window length for local AR prediction.
        ar_order : AR order (clamped to the window internally).
        direction : "forward", "reverse" (default) or "bidirectional"
            (average of both scans).
        center_pefa_before_integrate : subtract the PEFA mean before the
            cumulative sum.
        standardize_pefa : divide by the PEFA standard deviation before the
            cumulative sum (skipped for near-zero std).

        Returns
        -------
        (pefa, raw_inpefa) : both the same length as *data*; ``pefa`` is the
        unprocessed prediction error, ``raw_inpefa`` the integral of the
        processed PEFA (integrated bottom-up for "reverse").

        Raises
        ------
        ValueError : for an unknown *direction*.
        """
        data = InpefaCalculator._safe_fill_nan(data)
        n = len(data)
        if n < max(5, window_length + 1):
            # Too short to fill even one window: return flat curves.
            zeros = np.zeros_like(data, dtype=np.float64)
            return zeros, zeros

        direction = str(direction).lower().strip()
        if direction == "forward":
            pefa = InpefaCalculator._local_pefa_one_direction(data, window_length, ar_order)
        elif direction == "reverse":
            # Scan the flipped series, then flip the result back.
            pefa_rev = InpefaCalculator._local_pefa_one_direction(data[::-1], window_length, ar_order)
            pefa = pefa_rev[::-1]
        elif direction == "bidirectional":
            pefa_f = InpefaCalculator._local_pefa_one_direction(data, window_length, ar_order)
            pefa_r = InpefaCalculator._local_pefa_one_direction(data[::-1], window_length, ar_order)[::-1]
            pefa = 0.5 * (pefa_f + pefa_r)
        else:
            raise ValueError("direction 只能是 forward / reverse / bidirectional")

        pefa_proc = pefa.copy()
        if center_pefa_before_integrate:
            pefa_proc = pefa_proc - np.mean(pefa_proc)
        if standardize_pefa:
            std = np.std(pefa_proc)
            if std > 1e-12:
                pefa_proc = pefa_proc / std

        # Integrate in the same direction the PEFA was scanned.
        if direction == "reverse":
            raw_inpefa = np.cumsum(pefa_proc[::-1])[::-1]
        else:
            raw_inpefa = np.cumsum(pefa_proc)

        return pefa, raw_inpefa

    @staticmethod
    def _legacy_global_one_direction(data: np.ndarray, ar_order: int) -> np.ndarray:
        """Legacy INPEFA: one global AR fit, integrated prediction error."""
        clean_data = InpefaCalculator._safe_fill_nan(data)
        n = len(clean_data)
        if n < 3:
            return np.zeros_like(clean_data, dtype=np.float64)

        order = int(max(1, min(ar_order, n - 2)))
        if n < order + 1:
            return np.zeros_like(clean_data, dtype=np.float64)

        mean_val = np.mean(clean_data)
        centered_data = clean_data - mean_val
        pef_filter, _ = InpefaCalculator._burg_ar(centered_data, order)

        # Convolving with the prediction-error filter yields the residuals;
        # 'valid' mode skips the first `order` samples that lack history.
        valid_error = np.convolve(centered_data, pef_filter, mode='valid')
        prediction_error = np.zeros(n, dtype=np.float64)
        usable = min(len(valid_error), n - order)
        if usable > 0:
            prediction_error[order:order + usable] = valid_error[:usable]
        return np.cumsum(prediction_error)

    @staticmethod
    def calculate_legacy_global_inpefa(data: np.ndarray, ar_order: int, direction: str = "reverse", normalize_11: bool = False) -> np.ndarray:
        """Compute the global-scale legacy-global-INPEFA curve.

        With ``normalize_11=True`` the final curve is rescaled into [-1, 1].
        Raises ValueError for an unknown *direction*.
        """
        data = InpefaCalculator._safe_fill_nan(data)
        direction = str(direction).lower().strip()

        if direction == "forward":
            res = InpefaCalculator._legacy_global_one_direction(data, ar_order)
        elif direction == "reverse":
            legacy_rev = InpefaCalculator._legacy_global_one_direction(data[::-1], ar_order)
            res = legacy_rev[::-1]
        elif direction == "bidirectional":
            legacy_f = InpefaCalculator._legacy_global_one_direction(data, ar_order)
            legacy_r = InpefaCalculator._legacy_global_one_direction(data[::-1], ar_order)[::-1]
            res = 0.5 * (legacy_f + legacy_r)
        else:
            raise ValueError("direction 只能是 forward / reverse / bidirectional")

        # Optional min-max rescaling of the finite values into [-1, 1].
        if normalize_11:
            valid_mask = np.isfinite(res)
            if np.any(valid_mask):
                val_min = np.min(res[valid_mask])
                val_max = np.max(res[valid_mask])
                if val_max - val_min > 1e-12:  # guard against division by zero
                    res[valid_mask] = 2.0 * (res[valid_mask] - val_min) / (val_max - val_min) - 1.0
                else:
                    res[valid_mask] = 0.0  # constant curve: collapse to zero
        return res


# ================= 2. Global plotting configuration =================
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'WenQuanYi Micro Hei', 'SimSun', 'Arial'],
    'axes.unicode_minus': False,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'pdf.fonttype': 42,  # embed TrueType fonts in PDF output
})


# ================= 3. 通用工具函数 =================
def load_csv_safe(csv_path):
    """Read a CSV file, trying several encodings and an auto-sniffed separator.

    First pass uses the default comma separator over each encoding; second
    pass retries with ``sep=None`` (python engine) so pandas sniffs the
    delimiter. Returns a DataFrame, or None if every attempt fails.
    """
    encodings = ['utf-8-sig', 'utf-8', 'gbk', 'gb18030']
    errors = []
    # Pass 1: default comma separator.
    for enc in encodings:
        try:
            return pd.read_csv(csv_path, encoding=enc)
        except Exception as e:
            errors.append(f"[{enc}] {repr(e)}")
    # Pass 2: let pandas sniff the separator (requires the python engine).
    for enc in encodings:
        try:
            return pd.read_csv(csv_path, encoding=enc, sep=None, engine='python')
        except Exception as e:
            errors.append(f"[{enc}, auto-sep] {repr(e)}")
    # Fix: the per-attempt errors were collected but never shown; surface
    # them so a failing file can actually be diagnosed.
    print(f"❌ 读取文件失败: {csv_path}")
    for err in errors:
        print(f"   {err}")
    return None


def load_well_list(list_path, drop_duplicates=False):
    """Load well names from the first column of a CSV list file.

    Because pandas consumes the first data row as the header, the first
    column's name is also treated as a candidate well name unless it looks
    like a generic header label. Returns [] when the file is missing, empty
    or unreadable; with ``drop_duplicates=True`` duplicates are removed
    while preserving order.
    """
    if not os.path.exists(list_path):
        return []

    df_list = load_csv_safe(list_path)
    if df_list is not None and not df_list.empty:
        header_labels = ['井号', 'well', 'well_name', 'name']
        candidates = df_list.iloc[:, 0].dropna().astype(str).str.strip().tolist()
        # Fix: dropped the always-true `len(df_list.columns) >= 1` guard —
        # a non-empty DataFrame necessarily has at least one column.
        first_col_name = str(df_list.columns[0]).strip()
        if first_col_name and first_col_name.lower() not in header_labels:
            candidates.insert(0, first_col_name)

        cleaned = [x for x in candidates if x and x.lower() not in header_labels]
        if cleaned:
            # dict.fromkeys preserves first-seen order while deduplicating.
            return list(dict.fromkeys(cleaned)) if drop_duplicates else cleaned

    return []


def normalize_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """Map known column-name aliases onto the canonical Chinese names.

    Unrecognized columns are left untouched; a renamed copy is returned.
    """
    alias_groups = {
        '深度': ['深度', 'DEPTH', 'Depth', 'depth', '深度(m)', '深度（m）'],
        'GR': ['GR', 'gr', '自然伽马', '自然伽马GR', 'GR(API)', 'GR（API）'],
        '层号': ['层号', '层位', '地层', 'formation', 'Formation'],
        '顶深': ['顶深', '顶界深度', 'Top', 'top', 'TOP'],
        '底深': ['底深', '底界深度', 'Bottom', 'bottom', 'BOTTOM'],
    }
    # Flatten into alias -> canonical for a single dict lookup per column.
    alias_to_canonical = {
        alias: canonical
        for canonical, aliases in alias_groups.items()
        for alias in aliases
    }
    rename_map = {
        col: alias_to_canonical[str(col).strip()]
        for col in df.columns
        if str(col).strip() in alias_to_canonical
    }
    return df.rename(columns=rename_map)


def prepare_log_dataframe(df_log: pd.DataFrame) -> pd.DataFrame:
    """Clean a raw log table: numeric 深度/GR, positive GR, depth-sorted,
    unique depths, fresh index.

    If either required column is missing after alias normalization the
    (normalized) frame is returned unchanged.
    """
    df_log = normalize_column_names(df_log.copy())
    if '深度' not in df_log.columns or 'GR' not in df_log.columns:
        return df_log

    for col in ('深度', 'GR'):
        df_log[col] = pd.to_numeric(df_log[col], errors='coerce')

    cleaned = df_log.dropna(subset=['深度', 'GR'])
    cleaned = cleaned.loc[cleaned['GR'] > 0]  # discard non-physical GR values
    return (
        cleaned.sort_values('深度')
               .drop_duplicates(subset=['深度'], keep='first')
               .reset_index(drop=True)
    )


def _clean_name_for_match(text: str) -> str:
    return re.sub(r'[\s_\-（）()【】\[\]·,.，。]+', '', str(text).strip().lower())


def find_matched_file(directory, well_name):
    """Find a file in *directory* whose stem matches *well_name*.

    Match priority: exact cleaned-stem match, then stem ending with the
    name, then stem containing it. Within a priority level the first hit in
    os.listdir order wins. Returns the full path, or None.
    """
    if not os.path.exists(directory):
        return None

    target = _clean_name_for_match(well_name)
    buckets = {'exact': [], 'suffix': [], 'contain': []}
    for fname in os.listdir(directory):
        stem_key = _clean_name_for_match(os.path.splitext(fname)[0])
        path = os.path.join(directory, fname)
        if stem_key == target:
            buckets['exact'].append(path)
        elif stem_key.endswith(target):
            buckets['suffix'].append(path)
        elif target in stem_key:
            buckets['contain'].append(path)

    for level in ('exact', 'suffix', 'contain'):
        if buckets[level]:
            return buckets[level][0]
    return None


def extract_formations(formation_csv_path, target_names=None):
    """Extract target formation intervals from a formation-tops CSV.

    For each formation in *target_names* the overall interval is taken as
    (min top, max bottom) across its rows. Returns a list of
    ``{'name', 'top', 'bottom'}`` dicts in *target_names* order, skipping
    degenerate intervals; [] on any failure.
    """
    # NOTE: the default keeps only the two remaining target formations.
    if target_names is None:
        target_names = ['*********************', '*********************']

    df_fm = load_csv_safe(formation_csv_path)
    if df_fm is None:
        return []

    df_fm = normalize_column_names(df_fm.copy())
    if not all(col in df_fm.columns for col in ('层号', '顶深', '底深')):
        return []

    df_fm['层号'] = df_fm['层号'].astype(str).str.strip()
    for col in ('顶深', '底深'):
        df_fm[col] = pd.to_numeric(df_fm[col], errors='coerce')

    df_target = df_fm[df_fm['层号'].isin(target_names)].copy()
    if df_target.empty:
        return []

    # Collapse multiple rows per formation into one overall interval.
    grouped = (
        df_target
        .dropna(subset=['顶深', '底深'])
        .groupby('层号', as_index=False)
        .agg({'顶深': 'min', '底深': 'max'})
    )

    formations = [
        {'name': row['层号'], 'top': float(row['顶深']), 'bottom': float(row['底深'])}
        for _, row in grouped.iterrows()
        if float(row['顶深']) < float(row['底深'])
    ]
    # Preserve the caller's intended ordering of target formations.
    formations.sort(key=lambda f: target_names.index(f['name']))
    return formations


def _slice_df_by_formations(df, formations, expand=20.0):
    global_top = min(f['top'] for f in formations)
    global_bottom = max(f['bottom'] for f in formations)
    mask = (df['深度'] >= global_top - expand) & (df['深度'] <= global_bottom + expand)
    return df.loc[mask].copy().sort_values('深度').reset_index(drop=True), global_top, global_bottom


# ================= 新增：单独计算某一个地层的 raw-INPEFA =================
def compute_single_formation_raw(
    df_full: pd.DataFrame,
    formation: dict,
    window_length: int,
    ar_order: int,
    direction: str
) -> np.ndarray:
    """raw-INPEFA restricted to a single formation interval.

    Values inside [top, bottom] come from ``calculate_pefa_raw`` run on that
    GR segment alone; every other depth position is NaN. A missing formation
    or an empty interval yields an all-NaN curve.
    """
    result = np.full(len(df_full), np.nan, dtype=np.float64)
    if formation is None:
        return result

    in_layer = df_full['深度'].between(formation['top'], formation['bottom'])
    if not in_layer.any():
        return result

    segment = df_full.loc[in_layer, 'GR'].to_numpy(dtype=np.float64)
    if segment.size == 0:
        return result

    _, segment_raw = InpefaCalculator.calculate_pefa_raw(
        segment,
        window_length=window_length,
        ar_order=ar_order,
        direction=direction,
        center_pefa_before_integrate=True,
        standardize_pefa=True,
    )
    result[in_layer.to_numpy()] = segment_raw
    return result


def build_export_dataframe(
    df_full: pd.DataFrame,
    formations,
    window_length_raw: int,
    window_length_grouped: int,
    ar_order: int,
    direction: str
) -> pd.DataFrame:
    """Assemble the per-well export table.

    Columns: depth, GR, the whole-well long-window raw-INPEFA, the
    [-1, 1]-normalized legacy-global-INPEFA, and one short-window
    raw-INPEFA track per target formation.
    """
    gr_values = df_full['GR'].to_numpy(dtype=np.float64)
    depth_values = df_full['深度'].to_numpy(dtype=np.float64)

    # Whole-well curve computed with the large window.
    _, raw_inpefa = InpefaCalculator.calculate_pefa_raw(
        gr_values,
        window_length=window_length_raw,
        ar_order=ar_order,
        direction=direction,
        center_pefa_before_integrate=True,
        standardize_pefa=True,
    )

    # Legacy global-scale curve, rescaled to [-1, 1] for display.
    legacy_global = InpefaCalculator.calculate_legacy_global_inpefa(
        gr_values,
        ar_order=ar_order,
        direction=direction,
        normalize_11=True
    )

    by_name = {f['name']: f for f in formations}

    def _formation_track(name):
        # Short-window raw-INPEFA restricted to the named formation.
        return compute_single_formation_raw(
            df_full,
            formation=by_name.get(name),
            window_length=window_length_grouped,
            ar_order=ar_order,
            direction=direction
        )

    return pd.DataFrame({
        'Depth': depth_values,
        'GR': gr_values,
        'raw_INPEFA': raw_inpefa,
        'legacy_global_INPEFA': legacy_global,
        'raw_INPEFA_*********************1': _formation_track('*********************'),
        'raw_INPEFA_*********************2': _formation_track('*********************'),
    })


def add_formation_lines(axes, formations, x_text_anchor=None):
    """Draw dashed top/bottom boundary lines for every formation on each axis.

    When *x_text_anchor* is given, the formation name is written at its
    mid-depth on the first axis.
    """
    for formation in formations:
        top, bottom = formation['top'], formation['bottom']
        for ax in axes:
            for boundary in (top, bottom):
                ax.axhline(boundary, color='gray', linestyle='--', linewidth=0.7, alpha=0.6)
        if x_text_anchor is not None:
            midpoint = 0.5 * (top + bottom)
            axes[0].text(x_text_anchor, midpoint, formation['name'],
                         va='center', ha='left', fontsize=9, color='dimgray')


def save_bundle_plot(df_export: pd.DataFrame, formations, output_dir: str, well_name: str):
    """Render the multi-track depth plot for one well and save it as a PNG.

    One track per curve (GR plus four INPEFA variants), all sharing the
    depth axis, with formation boundaries overlaid. Returns the PNG path.
    """
    os.makedirs(output_dir, exist_ok=True)

    depth = df_export['Depth'].to_numpy(dtype=np.float64)
    # (df_export column, line colour, track label)
    tracks = [
        ('GR', 'black', 'GR'),
        ('raw_INPEFA', '#9467BD', 'raw-INPEFA'),
        ('legacy_global_INPEFA', '#17BECF', 'legacy-global-INPEFA'),
        ('raw_INPEFA_*********************1', '#2CA02C', '********************* raw-INPEFA'),
        ('raw_INPEFA_*********************2', '#FF7F0E', '********************* raw-INPEFA'),
    ]

    fig_width = max(15, 3.2 * len(tracks))
    fig, axes = plt.subplots(1, len(tracks), figsize=(fig_width, 11), dpi=300, sharey=True)
    if len(tracks) == 1:
        axes = [axes]  # subplots() returns a bare Axes for a single track

    for i, (col, color, label_text) in enumerate(tracks):
        ax = axes[i]
        values = df_export[col].to_numpy(dtype=np.float64)

        # Well-log convention: curve value on x, depth on y (inverted below).
        ax.plot(values, depth, color=color, linewidth=1.0, zorder=3)

        ax.set_xlabel(label_text, color=color, fontsize=10)
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')
        ax.tick_params(axis='x', colors=color, labelsize=8)
        ax.grid(True, which='major', linestyle='-', linewidth=0.4, color='lightgray')
        ax.yaxis.set_major_locator(MultipleLocator(100))
        ax.yaxis.set_minor_locator(MultipleLocator(25))
        ax.tick_params(axis='y', which='minor', length=4, direction='in')
        ax.tick_params(axis='y', which='major', length=8, direction='in')

        if i == 0:
            ax.set_ylabel('深度 (m)', fontsize=11)
        else:
            # Depth axis is shared; only the leftmost track shows labels.
            ax.tick_params(axis='y', labelleft=False)

        finite = values[np.isfinite(values)]
        if finite.size > 0:
            mean_val = np.mean(finite)

            # Reference lines: zero (gray dash-dot) and curve mean (red dashed).
            ax.axvline(x=0, color='gray', linestyle='-.', linewidth=0.8, alpha=0.7, zorder=1)
            ax.axvline(x=mean_val, color='red', linestyle='--', linewidth=1.0, alpha=0.8, zorder=2)

            # Pad x-limits slightly beyond the data range (constant curves
            # get a fixed pad so the line stays visible).
            lo, hi = np.min(finite), np.max(finite)
            if lo == hi:
                pad = max(1.0, abs(lo) * 0.05 + 1.0)
            else:
                pad = max((hi - lo) * 0.06, 1e-6)
            ax.set_xlim(lo - pad, hi + pad)
        else:
            ax.text(
                0.5, 0.5, '无数据',
                transform=ax.transAxes,
                ha='center', va='center',
                fontsize=10, color='gray'
            )

    # Depth increases downwards, as is conventional for well logs.
    axes[0].invert_yaxis()

    # Anchor formation-name labels at the left edge of the first track
    # (read AFTER the x-limits above have been set).
    x_anchor = axes[0].get_xlim()[0]
    add_formation_lines(axes, formations, x_text_anchor=x_anchor)

    fig.suptitle(
        f'{well_name} 曲线总对比图\nGR / raw-INPEFA / legacy-global-INPEFA / *********************raw-INPEFA / *********************raw-INPEFA',
        fontsize=13,
        fontweight='bold',
        y=0.98
    )
    fig.subplots_adjust(wspace=0.22, top=0.90)

    out_png = os.path.join(output_dir, f'{well_name}_CurveBundle.png')
    fig.savefig(out_png, bbox_inches='tight', dpi=300)
    plt.close(fig)  # free figure memory in the per-well batch loop
    print(f'  ✅ 图片保存至 -> {out_png}')
    return out_png


# ================= 4. Main program =================
if __name__ == '__main__':
    # Two window-length configurations.
    WINDOW_LENGTH_RAW = 122       # whole-well scale: larger window
    WINDOW_LENGTH_GROUPED = 62   # per-formation local scale: smaller window

    # NOTE(review): AR_ORDER far exceeds both windows; the calculator clamps
    # the order internally, so the effective order is bounded by window size.
    AR_ORDER = 480
    DIRECTION = 'reverse'
    FULL_EXPAND = 20.0  # metres of extra context kept around the target interval

    LIST_CSV = r'/home/*********************/DATA/*********************/list.csv'
    FM_DIR = r'/home/*********************/DATA/*********************/地层单位_CSV'
    LOG_DIR = r'/home/*********************/DATA/*********************/CSV_Output'
    BASE_OUTPUT_DIR = rf'/home/*********************/DATA/*********************/AR_{AR_ORDER}Output_Plots_Mini{WINDOW_LENGTH_RAW}_{WINDOW_LENGTH_GROUPED}'

    if not os.path.exists(LIST_CSV):
        print(f'找不到井列表文件: {LIST_CSV}')
        raise SystemExit(1)

    well_names = load_well_list(LIST_CSV, drop_duplicates=True)
    if not well_names:
        print(f'井列表文件为空或读取失败: {LIST_CSV}')
        raise SystemExit(1)

    print(f'📌 共找到 {len(well_names)} 口井需要处理。')

    for well in well_names:
        print(f'\n[{well}] ----------------------------')
        try:
            # 1) Locate and parse the formation-tops CSV for this well.
            fm_path = find_matched_file(FM_DIR, well)
            if not fm_path:
                print(f'  ❌ 跳过: 找不到对应的地层CSV ({well})')
                continue

            formations = extract_formations(fm_path)
            if not formations:
                # Changed: warning text updated to match the reduced target list.
                print('  ⚠️ 警告: 该井不包含 *********************/*********************，跳过。')
                continue
            print(f"  📖 匹配到地层: {[f['name'] for f in formations]}")

            # 2) Locate, load and clean the GR log CSV.
            log_path = find_matched_file(LOG_DIR, well)
            if not log_path:
                print(f'  ❌ 跳过: 找不到对应的测井曲线CSV ({well})')
                continue

            df_log = load_csv_safe(log_path)
            if df_log is None:
                continue
            df_log = prepare_log_dataframe(df_log)
            if '深度' not in df_log.columns or 'GR' not in df_log.columns:
                print(f"  ❌ 跳过: 数据中缺少 '深度' 或 'GR' 列。当前表格拥有的列为: {list(df_log.columns)}")
                continue
            if df_log.empty:
                print('  ❌ 跳过: 清洗后无有效的 GR 测井数据。')
                continue

            # 3) Restrict the log to the combined formation interval (+ margin).
            df_full, global_top, global_bottom = _slice_df_by_formations(df_log, formations, expand=FULL_EXPAND)
            if df_full.empty:
                print('  ❌ 跳过: 目标区段内无有效数据。')
                continue

            well_output_dir = os.path.join(BASE_OUTPUT_DIR, well)
            os.makedirs(well_output_dir, exist_ok=True)

            # 4) Compute every curve, then export CSV and the bundle plot.
            df_export = build_export_dataframe(
                df_full,
                formations,
                window_length_raw=WINDOW_LENGTH_RAW,
                window_length_grouped=WINDOW_LENGTH_GROUPED,
                ar_order=AR_ORDER,
                direction=DIRECTION
            )

            csv_path = os.path.join(well_output_dir, f'{well}_CurveBundle.csv')
            df_export.to_csv(csv_path, index=False, encoding='utf-8-sig')
            print(f'  ✅ CSV 保存至 -> {csv_path}')

            save_bundle_plot(df_export, formations, well_output_dir, well)

        except Exception as e:
            # Best-effort batch run: one bad well must not abort the loop.
            print(f'  ❌ 处理该井时发生未知异常，已跳过。报错信息: {e}')
            continue