"""
routes/hierarchical_routes.py
------------------------------
Endpoints for multi-level (hierarchical) sampling.

Design:
  - Caller sends the FULL dataset on every request (stateless/multitenant safe).
  - No data is retained between calls.
  - mask_rules narrow the effective population but data is never deleted.
"""
from fastapi import APIRouter
from pydantic import BaseModel, Field

from functions.hierarchical import (
    options_for_level,
    preview_last_level,
    hierarchical_sample,
    count_effective_rows,
)

# Router that the endpoint functions in this module register themselves on
# via @router.post(...). No prefix or tags are configured at construction.
router = APIRouter()


# ── Shared sub-models ─────────────────────────────────────────────────────────

# One conditional filter rule. Endpoints convert each rule to a plain dict
# with model_dump() before handing it to the sampling helpers.
class MaskRule(BaseModel):
    # Column the rule is evaluated against (presumably a key of each data row).
    col: str
    # Required comparison operator; the accepted spellings are listed in the
    # description but are not validated by this model.
    op: str = Field(
        default=...,
        description="Operator: ==, !=, in, not in, contains, not contains, >, >=, <, <=",
    )
    # Operand: a single string, a list of strings, or omitted entirely.
    value: str | list[str] | None = None
    # alias for value when it's a list
    values: list[str] | None = None


# ── 1. Level Options ──────────────────────────────────────────────────────────

# Request body for POST /options. `data`, `depth_levels` and `level` are
# required; the remaining fields fall back to empty / zero defaults.
class OptionsRequest(BaseModel):
    data: list[dict] = Field(
        default=...,
        description="Full dataset rows",
    )
    depth_levels: list[str] = Field(
        default=...,
        description="Ordered list of hierarchy columns, e.g. ['PROVINCE', 'DISTRICT', 'TEHSIL']",
    )
    level: str = Field(
        default=...,
        description="The level to fetch options for",
    )
    selected: dict[str, list[str]] = Field(
        default_factory=dict,
        description="Currently selected values per level",
    )
    mask_rules: list[MaskRule] = Field(
        default_factory=list,
        description="Conditional filter rules",
    )
    # ge=0 rejects negative thresholds at validation time.
    min_count: int = Field(
        default=0,
        ge=0,
        description="Minimum rows required to show a group option (applied on last level only)",
    )


@router.post("/options")
def get_options(payload: OptionsRequest):
    """
    Return selectable options for a given level in the hierarchy.
    Parent level must have selections before child options are returned.
    Returns value, count, path, and key for each group.
    """
    # Mask rules are forwarded as plain dicts (model_dump), not as MaskRule
    # instances.
    rule_dicts = [rule.model_dump() for rule in payload.mask_rules]

    # All computation is delegated to options_for_level; this endpoint only
    # wraps its result under the "options" key.
    return {
        "options": options_for_level(
            data=payload.data,
            depth_levels=payload.depth_levels,
            level=payload.level,
            selected=payload.selected,
            mask_rules=rule_dicts,
            min_count=payload.min_count,
        )
    }


# ── 2. Preview Last Level ─────────────────────────────────────────────────────

# Request body for POST /preview. Only `data` and `depth_levels` are required.
class PreviewRequest(BaseModel):
    data: list[dict] = Field(
        default=...,
        description="Full dataset rows",
    )
    depth_levels: list[str] = Field(
        default=...,
        description="Ordered hierarchy columns",
    )
    # Defaults to an empty mapping when the caller sends no selections.
    selected: dict[str, list[str]] = Field(default_factory=dict)
    # Defaults to no filter rules.
    mask_rules: list[MaskRule] = Field(default_factory=list)
    # Non-negative threshold (ge=0); defaults to 0.
    min_count: int = Field(default=0, ge=0)


@router.post("/preview")
def preview(payload: PreviewRequest):
    """
    Preview counts for every group in the last hierarchy level.
    Each item includes an 'eligible' flag (count >= min_count).
    """
    # Mask rules are forwarded as plain dicts (model_dump), not as MaskRule
    # instances.
    rule_dicts = [rule.model_dump() for rule in payload.mask_rules]

    # preview_last_level does the work; its result is returned under "preview".
    return {
        "preview": preview_last_level(
            data=payload.data,
            depth_levels=payload.depth_levels,
            selected=payload.selected,
            mask_rules=rule_dicts,
            min_count=payload.min_count,
        )
    }


# ── 3. Hierarchical Sample ────────────────────────────────────────────────────

# Request body for POST /sample. `data`, `depth_levels` and `sample_count`
# are required; everything else has a default.
class SampleRequest(BaseModel):
    data: list[dict] = Field(
        default=...,
        description="Full dataset rows",
    )
    depth_levels: list[str] = Field(
        default=...,
        description="Ordered hierarchy columns",
    )
    selected: dict[str, list[str]] = Field(
        default_factory=dict,
        description="Selected values per level",
    )
    mask_rules: list[MaskRule] = Field(default_factory=list)
    # Non-negative threshold (ge=0); defaults to 0.
    min_count: int = Field(
        default=0,
        ge=0,
        description="Groups with fewer rows than this are skipped",
    )
    # Must be strictly positive (gt=0).
    sample_count: int = Field(
        default=...,
        gt=0,
        description="Default number of rows to draw per group",
    )
    quota: dict[str, int] = Field(
        default_factory=dict,
        description="Per-group overrides: {full_path_key: n}. Key format: 'val1||val2||val3'",
    )
    random_state: int = Field(
        default=42,
        description="Random seed",
    )


@router.post("/sample")
def hierarchical_sample_endpoint(payload: SampleRequest):
    """
    Run hierarchical sampling across all selected last-level groups.

    - mask_rules narrow the effective population (all rows are still accepted; masking happens here)
    - selected filters apply at each parent level
    - quota overrides sample_count for specific groups (key = full path joined by '||')
    - Groups below min_count are skipped
    - Returns sampled rows + summary metadata
    """
    # Mask rules are forwarded as plain dicts (model_dump), not as MaskRule
    # instances.
    rule_dicts = [rule.model_dump() for rule in payload.mask_rules]

    rows = hierarchical_sample(
        data=payload.data,
        depth_levels=payload.depth_levels,
        selected=payload.selected,
        mask_rules=rule_dicts,
        min_count=payload.min_count,
        sample_count=payload.sample_count,
        quota=payload.quota,
        random_state=payload.random_state,
    )

    # Summary echoes the request parameters plus the number of rows actually
    # returned by the sampler.
    summary = {
        "method": "hierarchical",
        "depth_levels": payload.depth_levels,
        "sample_size": len(rows),
        "sample_count_default": payload.sample_count,
        "min_count": payload.min_count,
    }
    return {"sampled": rows, "info": summary}


# ── 4. Effective Row Count ────────────────────────────────────────────────────

# Request body for POST /effective-rows: the dataset plus optional mask rules.
class EffectiveRowsRequest(BaseModel):
    data: list[dict] = Field(
        default=...,
        description="Full dataset rows",
    )
    # Defaults to no filter rules.
    mask_rules: list[MaskRule] = Field(default_factory=list)


@router.post("/effective-rows")
def effective_rows(payload: EffectiveRowsRequest):
    """
    Count how many rows pass the mask rules vs. total rows.
    Useful for showing 'X of Y rows are eligible' in the UI.
    """
    # Mask rules are forwarded as plain dicts (model_dump), not as MaskRule
    # instances.
    rule_dicts = [rule.model_dump() for rule in payload.mask_rules]

    # The helper's return value is passed through unchanged as the response.
    counts = count_effective_rows(data=payload.data, mask_rules=rule_dicts)
    return counts
