"""
functions/sampling.py
---------------------
Pure sampling functions.
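All accept a plain Python list[dict] as input. The sampling functions return a dict with a
`sampled` list[dict] and an `info` metadata dict; get_cluster_info returns a list[dict] of counts.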
No shared state, no file I/O — stateless by design.
"""
import math
import random
from itertools import product as iproduct
from typing import Any


# ── helpers ──────────────────────────────────────────────────────────────────

def _to_records(data: list[dict]) -> list[dict]:
    return [dict(row) for row in data]


def _seed(random_state: int | None) -> int:
    return random_state if random_state is not None else 42


# ── 1. Simple Random Sampling ─────────────────────────────────────────────────

def simple_random_sampling(
    data: list[dict],
    n: int,
    random_state: int | None = None,
) -> dict:
    """
    Draw `n` rows uniformly at random, without replacement.

    `n` is clamped into [0, len(data)]. The RNG is seeded with
    `random_state` (42 when None), so results are reproducible.

    Returns:
        {"sampled": [row copies], "info": {method, population_size,
        sample_size, random_state}}
    """
    seed = _seed(random_state)
    records = _to_records(data)
    size = len(records)
    take = min(max(n, 0), size)

    picked = random.Random(seed).sample(records, take)

    info = {
        "method": "simple_random",
        "population_size": size,
        "sample_size": len(picked),
        "random_state": seed,
    }
    return {"sampled": picked, "info": info}


# ── 2. Systematic (Interval-Based) Sampling ───────────────────────────────────

def systematic_sampling(
    data: list[dict],
    n: int,
    random_state: int | None = None,
) -> dict:
    """
    Interval/systematic sampling.

    The interval is k = N // n. A start offset is drawn uniformly from [0, k)
    with a seeded RNG (42 when `random_state` is None). Every k-th row from
    that offset is taken until `n` rows are collected, in original order.
    `n` is clamped into [0, N]. Empty input or n == 0 yields an empty sample.
    """
    records = _to_records(data)
    total = len(records)
    want = max(0, min(n, total))

    if not (total and want):
        return {
            "sampled": [],
            "info": {"method": "systematic", "population_size": total, "sample_size": 0},
        }

    step = max(1, total // want)
    offset = random.Random(_seed(random_state)).randint(0, step - 1)

    # offset + (want - 1) * step <= want * step - 1 <= total - 1, so the
    # stride always yields at least `want` positions; no wrap-around needed.
    picked = records[offset::step][:want]

    return {
        "sampled": picked,
        "info": {
            "method": "systematic",
            "population_size": total,
            "sample_size": len(picked),
            "interval_k": step,
            "start_index_0based": offset,
            "start_index_1based": offset + 1,
        },
    }


# ── 3. Stratified Sampling ────────────────────────────────────────────────────

def _build_strata_cells(strata_list: list[dict]) -> list[dict]:
    """
    Build cartesian product cells from strata definition.

    strata_list: [
        {"column": "Gender", "filters": [{"value": "Male", "pct": 60}, {"value": "Female", "pct": 40}]},
        {"column": "Region", "filters": [{"value": "North", "pct": 50}, {"value": "South", "pct": 50}]},
    ]
    """
    if not strata_list:
        return []

    cols_filters = [(s["column"], s.get("filters", [])) for s in strata_list]

    if len(cols_filters) == 1:
        col, filters = cols_filters[0]
        return [
            {
                "label": f"{col}: {f['value']}",
                "filters": {col: f["value"]},
                "pct": float(f.get("pct", 0)),
            }
            for f in filters
        ]

    all_filter_lists = [[(col, f) for f in filters] for col, filters in cols_filters]
    cells = []
    for combo in iproduct(*all_filter_lists):
        label_parts, filter_dict = [], {}
        pct = 100.0
        for col, f in combo:
            label_parts.append(f"{col}: {f['value']}")
            filter_dict[col] = f["value"]
            pct = pct * float(f.get("pct", 0)) / 100.0
        cells.append({"label": " | ".join(label_parts), "filters": filter_dict, "pct": pct})
    return cells


def stratified_sampling(
    data: list[dict],
    n: int,
    strata: list[dict],
    random_state: int | None = None,
) -> dict:
    """
    Proportional stratified sampling.

    strata: [
        {"column": "Gender", "filters": [{"value": "Male", "pct": 60}, {"value": "Female", "pct": 40}]}
    ]

    The strata are expanded into cells (see `_build_strata_cells`). A row
    belongs to a cell when, for every filter column, the stripped string form
    of its value equals the stripped string form of the filter value.

    Each cell is allotted round(pct / total_pct * n) rows. A cell with a
    positive pct gets at least one row, provided n > 0. A cell with a zero or
    negative pct gets none, and so does every cell when n <= 0 or the pcts sum
    to <= 0. Allotments are capped at the cell's row count. Because of
    rounding, the one-row minimum and the caps, the final sample size can
    differ from `n`.

    Falls back to `simple_random_sampling` when `strata` is empty or expands
    to no cells.

    Returns:
        {"sampled": [row copies], "info": {method, population_size,
        sample_size, breakdown}}, where breakdown has one entry per cell with
        label, group_population, allocated_pct and sample_count.
    """
    rows = _to_records(data)
    population = len(rows)
    rng = random.Random(_seed(random_state))

    if not strata:
        return simple_random_sampling(data, n, random_state)

    cells = _build_strata_cells(strata)
    if not cells:
        return simple_random_sampling(data, n, random_state)

    total_pct = sum(c["pct"] for c in cells)
    parts, breakdown = [], []

    for cell in cells:
        # Narrow the population one filter column at a time.
        cell_rows = rows
        for col, val in cell["filters"].items():
            target = str(val).strip()
            cell_rows = [r for r in cell_rows if str(r.get(col, "")).strip() == target]

        # Only cells with a positive share get the one-row minimum; a 0% cell
        # or a non-positive request must not pull rows into the sample.
        if n > 0 and total_pct > 0 and cell["pct"] > 0:
            alloc = max(1, round((cell["pct"] / total_pct) * n))
        else:
            alloc = 0
        alloc = min(alloc, len(cell_rows))

        sampled_cell = rng.sample(cell_rows, alloc) if alloc > 0 else []
        parts.extend(sampled_cell)

        breakdown.append({
            "label": cell["label"],
            "group_population": len(cell_rows),
            "allocated_pct": round(cell["pct"], 4),
            "sample_count": alloc,
        })

    return {
        "sampled": parts,
        "info": {
            "method": "stratified",
            "population_size": population,
            "sample_size": len(parts),
            "breakdown": breakdown,
        },
    }


# ── 4. Cluster Sampling ───────────────────────────────────────────────────────

def get_cluster_info(data: list[dict], cluster_column: str) -> list[dict]:
    """
    Count rows per distinct value of `cluster_column`.

    Values are keyed by their `str()` form. Rows lacking the column are
    counted under "". The result is [{"name", "count"}, ...] ordered by
    descending count. Ties keep first-seen order.
    """
    tally: dict[str, int] = {}
    for record in data:
        label = str(record.get(cluster_column, ""))
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1

    summary = [{"name": label, "count": total} for label, total in tally.items()]
    summary.sort(key=lambda item: -item["count"])
    return summary


def cluster_sampling(
    data: list[dict],
    n: int,
    cluster_column: str,
    mode: str = "auto",
    n_clusters: int | None = None,
    min_cluster_size: int = 0,
    manual_clusters: list[str] | None = None,
    random_state: int | None = None,
) -> dict:
    """
    Cluster sampling with proportional within-cluster allocation.

    Rows are grouped by str(row[cluster_column]); rows without the column go
    to the "" cluster. Clusters are then chosen according to `mode`:
      "auto"   — randomly select eligible clusters (>= min_cluster_size)
      "manual" — use explicit manual_clusters list

    In "auto" mode, a truthy `n_clusters` smaller than the number of eligible
    clusters draws that many of them at random. `n_clusters` is ignored in
    "manual" mode.

    Each selected cluster contributes round(cluster_rows / selected_rows * n)
    random rows. That is at least one row per cluster when n > 0, none when
    n <= 0, and never more than the cluster holds.

    Returns:
        {"sampled": [row copies], "info": {method, population_size,
        sample_size, cluster_column, cluster_mode, min_cluster_size,
        clusters_selected, breakdown}}

    Raises:
        ValueError: mode="manual" without manual_clusters; no cluster meets
            min_cluster_size (including empty `data`); or no clusters end up
            selected.
    """
    rows = _to_records(data)
    rng = random.Random(_seed(random_state))

    # Build cluster index
    cluster_index: dict[str, list[dict]] = {}
    for row in rows:
        key = str(row.get(cluster_column, ""))
        cluster_index.setdefault(key, []).append(row)

    all_clusters = [
        {"name": k, "count": len(v)}
        for k, v in cluster_index.items()
    ]

    if mode == "manual":
        if not manual_clusters:
            raise ValueError("manual_clusters is required when mode='manual'")
        manual_set = {str(c) for c in manual_clusters}
        eligible = [c for c in all_clusters if c["name"] in manual_set]
    else:
        eligible = [c for c in all_clusters if c["count"] >= int(min_cluster_size)]
        if not eligible:
            # default=0: with empty data there are no clusters at all, and a
            # bare max() would raise its own unrelated ValueError instead.
            largest = max((c["count"] for c in all_clusters), default=0)
            raise ValueError(
                f"No clusters meet min_cluster_size={min_cluster_size}. "
                f"Largest cluster has {largest} rows."
            )
        if n_clusters:
            # Convert before comparing, mirroring int(min_cluster_size) above.
            wanted = int(n_clusters)
            if wanted < len(eligible):
                eligible = rng.sample(eligible, wanted)

    if not eligible:
        raise ValueError("No clusters selected.")

    chosen_names = [c["name"] for c in eligible]
    total_eligible_rows = sum(c["count"] for c in eligible)

    parts, breakdown = [], []
    for cluster in eligible:
        cname = cluster["name"]
        ccount = cluster["count"]
        crows = cluster_index[cname]

        # Proportional share with a one-row floor so every selected cluster is
        # represented — but only when rows were actually requested.
        alloc = max(1, round((ccount / total_eligible_rows) * n)) if n > 0 else 0
        alloc = min(alloc, ccount)

        sampled_c = rng.sample(crows, alloc)
        parts.extend(sampled_c)

        breakdown.append({
            "cluster": cname,
            "cluster_population": ccount,
            "allocated": alloc,
            "sampled_from_cluster": len(sampled_c),
        })

    return {
        "sampled": parts,
        "info": {
            "method": "cluster",
            "population_size": len(rows),
            "sample_size": len(parts),
            "cluster_column": cluster_column,
            "cluster_mode": mode,
            "min_cluster_size": min_cluster_size,
            "clusters_selected": chosen_names,
            "breakdown": breakdown,
        },
    }
