"""
routes/sampling_routes.py
--------------------------
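Endpoints for the four core sampling techniques, plus a `/cluster-info`
utility for previewing clusters before running cluster sampling.

Every sampling endpoint:
  - Receives data as JSON (list[dict]) in the request body
  - Performs calculations purely on that data
  - Returns sampled rows + sampling metadata
  - Shares NO state with any other request or endpoint

The `/cluster-info` utility also receives its data in the request body, but it
returns per-cluster row counts instead of sampled rows.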
"""
from typing import Literal

from fastapi import APIRouter
from pydantic import BaseModel, Field

from functions.sampling import (
    simple_random_sampling,
    systematic_sampling,
    stratified_sampling,
    cluster_sampling,
    get_cluster_info,
)

# Router that collects every endpoint declared in this module via @router.post.
router = APIRouter()


# ── Shared base ───────────────────────────────────────────────────────────────

class BaseRequest(BaseModel):
    """Request fields shared by the sampling endpoints.

    Used directly by the simple-random and systematic endpoints, and
    extended by the stratified and cluster request models.
    """

    # Dataset rows, each one a JSON object.
    data: list[dict] = Field(default=..., description="Dataset rows as a list of objects")
    # Requested sample size; validation rejects zero and negative values.
    n: int = Field(default=..., description="Desired sample size", gt=0)
    # Optional seed, forwarded unchanged to the sampling helper.
    random_state: int | None = Field(default=None, description="Random seed for reproducibility")


# ── 1. Simple Random Sampling ─────────────────────────────────────────────────

@router.post("/simple-random")
def simple_random(payload: BaseRequest):
    """
    Simple random sampling: pick ``n`` rows from ``data`` without replacement.

    ``random_state`` is passed through to the helper as the seed.
    """
    rows, sample_size, seed = payload.data, payload.n, payload.random_state
    return simple_random_sampling(data=rows, n=sample_size, random_state=seed)


# ── 2. Systematic (Interval-Based) Sampling ───────────────────────────────────

@router.post("/systematic")
def systematic(payload: BaseRequest):
    """
    Systematic (interval-based) sampling.

    The step is k = N // n. Sampling starts at a randomly chosen row and then
    takes every k-th row from there.
    """
    helper_kwargs = {
        "data": payload.data,
        "n": payload.n,
        "random_state": payload.random_state,
    }
    return systematic_sampling(**helper_kwargs)


# ── 3. Stratified Sampling ────────────────────────────────────────────────────

class StratumFilter(BaseModel):
    """One value of a stratum column together with its percentage."""

    # Column value that identifies this stratum cell.
    value: str
    # Percentage assigned to this value; validated to lie in [0, 100].
    pct: float = Field(default=..., le=100, ge=0)


class Stratum(BaseModel):
    """A column to stratify on, with its value/percentage breakdown."""

    column: str = Field(...)  # name of the column read from each data row
    filters: list[StratumFilter] = Field(...)  # value/percentage pairs for `column`


class StratifiedRequest(BaseRequest):
    """Body for the stratified endpoint: the shared fields plus strata definitions."""

    # One entry per stratification column. When several columns are given,
    # the strata are their cartesian product.
    strata: list[Stratum] = Field(
        default=...,
        description=(
            "List of strata definitions. Each stratum specifies a column and its "
            "value-to-percentage mapping. Multiple columns produce a cartesian product."
        ),
    )


@router.post("/stratified")
def stratified(payload: StratifiedRequest):
    """
    Proportional stratified sampling.

    Each stratum names a column and gives a percentage for each of its values.
    Rows in every stratum cell are drawn in proportion to those percentages.
    When several columns are supplied, they are crossed (cartesian product) to
    form nested strata.
    """
    # The helper receives the strata as plain dicts, not Pydantic models.
    plain_strata = [stratum.model_dump() for stratum in payload.strata]
    rows = payload.data
    return stratified_sampling(
        data=rows,
        n=payload.n,
        strata=plain_strata,
        random_state=payload.random_state,
    )


# ── 4. Cluster Sampling ───────────────────────────────────────────────────────

class ClusterRequest(BaseRequest):
    """Body for the cluster endpoint.

    Extends the shared ``data`` / ``n`` / ``random_state`` fields with the
    column that defines clusters and the options controlling which clusters
    are selected.
    """

    cluster_column: str = Field(..., description="Column whose unique values define clusters")
    # Literal restricts input to the two documented modes. A typo such as
    # "atuo" is rejected with a 422 validation error instead of being passed
    # through to cluster_sampling. Accepted values are still plain strings.
    mode: Literal["auto", "manual"] = Field(
        "auto", description="'auto' (random selection) or 'manual' (explicit list)"
    )
    n_clusters: int | None = Field(None, description="How many clusters to select in auto mode")
    min_cluster_size: int = Field(
        0, ge=0, description="Minimum rows a cluster must have to be eligible (auto mode)"
    )
    manual_clusters: list[str] | None = Field(None, description="Explicit cluster names for manual mode")


@router.post("/cluster")
def cluster(payload: ClusterRequest):
    """
    Cluster sampling with proportional allocation inside the chosen clusters.

    In 'auto' mode, n_clusters are chosen at random from clusters with at least
    min_cluster_size rows. In 'manual' mode, the clusters named in
    manual_clusters are used. In both modes, rows are then sampled
    proportionally.
    """
    selection = dict(
        cluster_column=payload.cluster_column,
        mode=payload.mode,
        n_clusters=payload.n_clusters,
        min_cluster_size=payload.min_cluster_size,
        manual_clusters=payload.manual_clusters,
    )
    return cluster_sampling(
        data=payload.data,
        n=payload.n,
        random_state=payload.random_state,
        **selection,
    )


# ── Utility: Cluster Info ─────────────────────────────────────────────────────

class ClusterInfoRequest(BaseModel):
    """Body for the cluster-info endpoint.

    It has no ``n`` or ``random_state``, so it derives from BaseModel rather
    than BaseRequest.
    """

    data: list[dict] = Field(default=..., description="Dataset rows")
    cluster_column: str = Field(default=..., description="Column to group by")
    # Row-count threshold used to build the "eligible" subset.
    min_size: int = Field(default=0, description="Filter eligible clusters by minimum size", ge=0)


@router.post("/cluster-info")
def cluster_info(payload: ClusterInfoRequest):
    """
    List every unique cluster value with its row count.

    Call this before cluster sampling to preview the clusters and see which
    ones meet the minimum-size threshold.
    """
    clusters = get_cluster_info(payload.data, payload.cluster_column)
    threshold = payload.min_size
    if threshold > 0:
        eligible = [entry for entry in clusters if entry["count"] >= threshold]
    else:
        # Without a threshold, every cluster is eligible.
        eligible = clusters
    summary = {"clusters": clusters, "eligible": eligible}
    summary["total_clusters"] = len(clusters)
    summary["eligible_count"] = len(eligible)
    return summary
