import requests
import json
import re
import argparse
import sys

# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
# Connection defaults read by parse_args(); each can be overridden on the command line.
AUTH_URL = ""
EMAIL = ""
PASSWORD = ""
OUTPUT_FILE = "schema_output.json"

# Wrapper keys (compared case-insensitively, tried in this order) under which
# API responses commonly nest their records.
COMMON_RECORD_KEYS = (
    "data", "results", "items", "records", "rows",
    "docs", "entries", "payload", "result",
)

# Keys (compared case-insensitively) that mark a dict as a status/error
# envelope; such a dict is not treated as a single record.
COMMON_ENVELOPE_KEYS = ("success", "error", "errors", "detail", "meta")


# ─────────────────────────────────────────────


def fetch_data(auth_url, email, password, auth_mode="credentials", token="", request_method="POST", timeout_s=60):
    """Fetch records from the API using either credentials or bearer-token auth.

    Args:
        auth_url: Endpoint to call.
        email, password: Sent as the JSON body ``{"email": ..., "password": ...}``
            when ``auth_mode`` is not ``"token"`` and the request is a POST.
        auth_mode: ``"token"`` (case-insensitive) for bearer auth; any other
            value (or empty) means credentials.
        token: Bearer token for token mode; the header is omitted when empty.
        request_method: ``"GET"`` (case-insensitive) issues a GET; any other
            value, or empty, issues a POST.
        timeout_s: Timeout in seconds passed to requests.

    Returns:
        The list of record dicts extracted by ``normalize_records``, or an
        empty list (after printing a truncated preview of the payload) when
        no records could be found.

    Raises:
        requests.exceptions.HTTPError: on a 4xx/5xx status (``raise_for_status``).
        requests.exceptions.RequestException: on connection/timeout errors.
        ValueError: (a subclass of it) if the response body is not valid JSON.
    """
    print("[1/2] Fetching data...")
    method = (request_method or "POST").upper()
    auth_mode = (auth_mode or "credentials").lower()

    headers = {}
    payload = None

    if auth_mode == "token":
        if token:
            headers["Authorization"] = f"Bearer {token}"
    else:
        payload = {"email": email, "password": password}
        if method == "GET":
            # The GET branch below sends no body, so the credentials would be
            # dropped without any sign; surface that instead.
            print("[WARN] Credentials are not sent with GET requests; use POST or token auth.")

    # The context manager closes the session's connection pool even when the
    # request or status check raises.
    with requests.Session() as session:
        # Don't pick up proxy / default-auth settings from the environment.
        session.trust_env = False

        if method == "GET":
            res = session.get(auth_url, headers=headers, timeout=timeout_s)
        else:
            res = session.post(auth_url, json=payload, headers=headers, timeout=timeout_s)

        res.raise_for_status()
        raw = res.json()

    records = normalize_records(raw)

    if records:
        print(f"   [OK] {len(records)} record(s) fetched.")
        return records

    print("[WARN] Unexpected response shape:")
    print(json.dumps(raw, indent=2)[:500])
    return []


def normalize_records(raw):
    """Extract a flat list of record dicts from common API response shapes.

    Resolution order:
      1. A list: its dict items, or else the first nested list containing dicts.
      2. A dict: the first list of dicts found under a conventional wrapper key
         (``COMMON_RECORD_KEYS``, compared case-insensitively, in tuple order).
      3. Otherwise the first list of dicts found anywhere in the dict.
      4. Otherwise whatever a wrapper key's value normalizes to (e.g. a single
         object under ``"data"``).
      5. A dict holding any ``COMMON_ENVELOPE_KEYS`` key yields ``[]``.
      6. Any other non-empty dict is returned as a single record.

    Scalars, empty lists and empty dicts yield ``[]``.
    """
    if isinstance(raw, list):
        return find_nested_record_list(raw)

    if not isinstance(raw, dict):
        return []

    # Values stored under conventional wrapper keys, in COMMON_RECORD_KEYS order.
    wrapped_values = [
        value
        for key in COMMON_RECORD_KEYS
        for actual_key, value in raw.items()
        if str(actual_key).lower() == key
    ]

    # Prefer records under a conventional key so that an unrelated list of
    # objects appearing earlier in the payload (for example pagination links
    # inside "meta") is not mistaken for the records.
    for value in wrapped_values:
        nested = find_nested_record_list(value)
        if nested:
            return nested

    nested = find_nested_record_list(raw)
    if nested:
        return nested

    # No list of dicts anywhere: a wrapper key may still hold a single object.
    for value in wrapped_values:
        nested = normalize_records(value)
        if nested:
            return nested

    if any(str(key).lower() in COMMON_ENVELOPE_KEYS for key in raw):
        return []

    # Treat a plain object as one record so object-style APIs still yield a
    # schema instead of an empty one.
    return [raw] if raw else []


def find_nested_record_list(value):
    """Return the first list of record dicts found by a depth-first search.

    For a list, its dict elements are returned immediately if there are any
    (non-dict elements are dropped); otherwise each element is searched in
    order. For a dict, its values are searched in insertion order. Returns
    ``[]`` when no list containing dicts exists anywhere in ``value``.
    """
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, list):
        records = [entry for entry in value if isinstance(entry, dict)]
        if records:
            return records
        children = value
    else:
        return []

    for child in children:
        found = find_nested_record_list(child)
        if found:
            return found
    return []


def infer_datatype(value):
    """Map a sample value to a ``(datatype, semantic_type)`` pair.

    Non-string values are classified by Python type: None -> varchar,
    bool -> boolean, int -> int, float -> decimal, list -> array,
    dict -> json (semantic type ``""``). Anything else is stringified and
    stripped; if it begins like a date/time it becomes ``("varchar", "date")``
    so imports can keep the source's date format as text, otherwise
    ``("varchar", "")``.
    """
    if value is None:
        return "varchar", ""

    # Order matters: bool is a subclass of int, so it has to be tested first.
    type_labels = (
        (bool, "boolean"),
        (int, "int"),
        (float, "decimal"),
        (list, "array"),
        (dict, "json"),
    )
    for python_type, label in type_labels:
        if isinstance(value, python_type):
            return label, ""

    text = str(value).strip()

    # Several patterns are anchored only at the start, so a string that merely
    # *begins* like a date/time (e.g. "2024/01/01 extra") also counts.
    date_like_patterns = (
        r"^\d{4}-\d{2}-\d{2}$",  # 2024-01-01
        r"^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}",  # 2024-01-01 12:00:00 or 2024-01-01T12:00:00
        r"^\d{4}/\d{2}/\d{2}",  # 2024/01/01
        r"^\d{2}-\d{2}-\d{4}",  # 01-12-2024 or 12-01-2024
        r"^\d{2}/\d{2}/\d{4}",  # 01/12/2024 or 12/01/2024
        r"^\d{1,2}-\d{1,2}-\d{4}",  # 1-1-2024
        r"^\d{1,2}/\d{1,2}/\d{4}",  # 1/1/2024
        r"^\d{2}:\d{2}:\d{2}",  # 12:00:00
        r"^\d{1,2}:\d{2}(:\d{2})?(\s?[AP]M)?$",  # 12:00, 12:00:00, 12:00 PM (case-insensitive)
        r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z?$",  # ISO 8601 with fractional seconds
        r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}[+-]\d{2}:\d{2}$",  # ISO 8601 with UTC offset
    )
    looks_like_date = any(
        re.match(pattern, text, re.IGNORECASE) for pattern in date_like_patterns
    )
    return ("varchar", "date") if looks_like_date else ("varchar", "")


def is_lookup_field(key, seen_values):
    """Return True if every non-empty value is a one- or two-digit numeric string.

    Such fields look like small integer codes pointing into a lookup table.

    Args:
        key: Field name; not used by the check (kept for interface compatibility).
        seen_values: Values observed for the field; None and "" are ignored.

    Returns:
        False when no non-empty values remain. Otherwise, True only if every
        remaining value is a ``str`` made of exactly one or two digits; non-string
        values (including JSON integers) make the result False.
    """
    non_null = [v for v in seen_values if v is not None and v != ""]
    if not non_null:
        return False
    # fullmatch rather than match with "$": "$" also matches just before a
    # trailing newline, which would wrongly accept values such as "12\n".
    return all(
        isinstance(v, str) and re.fullmatch(r"\d{1,2}", v) is not None
        for v in non_null
    )


def make_field(field_id, key, seen_values):
    """Build one schema field definition from the values observed for a column.

    Args:
        field_id: Stored as the field's "id".
        key: Stored as the field's "name".
        seen_values: Every value observed for the column. As assembled by
            build_schema, None marks rows where the field was absent.

    Returns:
        A field dict:
        - The datatype/type pair comes from ``infer_datatype`` applied to the
          first non-empty value. None and "" are skipped, consistent with how
          ``nullable`` treats them.
        - An "int" column that also holds floats is widened to "decimal".
        - ``nullable`` is True when any value is None or "".
        - varchar fields get length 255; all other types get None.
    """
    present = [v for v in seen_values if v is not None and v != ""]
    nullable = any(v is None or v == "" for v in seen_values)

    first_val = present[0] if present else None
    datatype, semantic_type = infer_datatype(first_val)

    # Inference samples a single value; a column such as [10, 10.5] must not be
    # typed as int just because its first sample happened to be integral.
    # (bool is not a float, so boolean values never trigger this.)
    if datatype == "int" and any(isinstance(v, float) for v in present):
        datatype = "decimal"

    return {
        "id": field_id,
        "name": key,
        "multiple_options": False,
        "datatype": datatype,
        "type": semantic_type,
        "nullable": nullable,
        "predefined": "",
        "length": 255 if datatype == "varchar" else None,
        "is_lookup": False,
        "lookup_reference": ""
    }


def _append_aligned_row(columns, row, row_count):
    """Append one row to column-oriented storage, keeping every column equal length.

    Args:
        columns: Mapping of field name -> list of values. Each list holds
            exactly ``row_count`` entries before the call; it is mutated in place.
        row: Mapping of field name -> value for the new row.
        row_count: Number of rows already stored in ``columns``.

    A field first seen in ``row`` is back-filled with ``row_count`` Nones; a
    known field missing from ``row`` receives None for this row.
    """
    for name in row:
        if name not in columns:
            columns[name] = [None] * row_count
    for name, cells in columns.items():
        cells.append(row.get(name))


def build_schema(records):
    """Infer a schema document from a list of record dicts.

    Each record's values become columns of the "submissions" table. The one
    exception is a non-empty list whose first element is a dict: it is treated
    as a child table named after its key. Its dict rows are merged column-wise
    across all records; non-dict items inside such a list are skipped.

    Returns:
        ``{"Schema": {"Tables": {"submissions": [...]}, "Lookup_tables": [],
        "Child_tables": {name: [...]}}}``, where every list holds field dicts
        produced by ``make_field``.
    """
    print("[2/2] Building schema...")

    flat_values = {}
    child_values = {}
    child_row_counts = {}

    for record_count, record in enumerate(records):
        current_flat = {}

        for key, value in record.items():
            if isinstance(value, list) and value and isinstance(value[0], dict):
                columns = child_values.setdefault(key, {})
                for row in value:
                    if not isinstance(row, dict):
                        # A stray scalar in a list of objects has no columns;
                        # skip it rather than crash on row.keys().
                        continue
                    row_count = child_row_counts.get(key, 0)
                    _append_aligned_row(columns, row, row_count)
                    child_row_counts[key] = row_count + 1
            else:
                current_flat[key] = value

        _append_aligned_row(flat_values, current_flat, record_count)

    # ── Submissions table ──
    # A key that is a child table in some records but only [] / None (or absent)
    # elsewhere carries no column data, so don't also emit it as a submissions field.
    submissions = [
        make_field(key, key, values)
        for key, values in flat_values.items()
        if not (key in child_values and all(v is None or v == [] for v in values))
    ]

    # ── Child tables ──
    child_tables = {
        tname: [make_field(ck, ck, cvalues) for ck, cvalues in fields_map.items()]
        for tname, fields_map in child_values.items()
    }

    return {
        "Schema": {
            "Tables": {"submissions": submissions},
            "Lookup_tables": [],
            "Child_tables": child_tables
        }
    }


def parse_args():
    """Parse command-line options; options not given fall back to the CONFIG constants.

    NOTE(review): values passed via --password / --token appear in the process
    argument list; consider an environment-variable fallback for secrets.
    """
    cli = argparse.ArgumentParser(description="Build schema JSON from API response")
    option_specs = [
        ("--url", dict(default=AUTH_URL, help="API URL")),
        ("--email", dict(default=EMAIL)),
        ("--password", dict(default=PASSWORD)),
        ("--auth_mode", dict(default="credentials", choices=["credentials", "token"])),
        ("--token", dict(default="")),
        ("--request_method", dict(default="POST", choices=["GET", "POST"])),
        ("--timeout_s", dict(default=60, type=int)),
        ("--output", dict(default=OUTPUT_FILE)),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def main():
    """Command-line entry point: validate options, fetch records, write the schema.

    Exit codes:
        1: required option missing.
        2: no records received.
        3: schema could not be built.
        4: HTTP error status.
        5: connection or other request failure.
        6: any other unexpected error (traceback printed).
    """
    try:
        args = parse_args()

        # Only the first failing check is reported, in the same order as before:
        # URL, then credentials (credentials mode), then token (token mode).
        problem = None
        if not args.url.strip():
            problem = "\n[ERROR] API URL is required."
        elif args.auth_mode == "credentials" and (not args.email.strip() or not args.password.strip()):
            problem = "\n[ERROR] Email and password required."
        elif args.auth_mode == "token" and not args.token.strip():
            problem = "\n[ERROR] Token required."

        if problem is not None:
            print(problem)
            sys.exit(1)

        records = fetch_data(
            auth_url=args.url,
            email=args.email,
            password=args.password,
            auth_mode=args.auth_mode,
            token=args.token,
            request_method=args.request_method,
            timeout_s=args.timeout_s,
        )
        if not records:
            print("\n[ERROR] No data received.")
            sys.exit(2)

        schema = build_schema(records)
        if not schema or "Schema" not in schema:
            print("\n[ERROR] Failed to build schema.")
            sys.exit(3)

        with open(args.output, "w", encoding="utf-8") as handle:
            json.dump(schema, handle, indent=2, ensure_ascii=False)

        body = schema["Schema"]
        print(f"\n[OK] Schema saved to: {args.output}")
        print(f"  • Submission fields : {len(body['Tables']['submissions'])}")
        print(f"  • Child tables      : {len(body['Child_tables'])}")
        print(f"  • Lookup tables     : {len(body['Lookup_tables'])}")

    # sys.exit raises SystemExit, which is not an Exception subclass, so the
    # exits above pass through these handlers untouched.
    except requests.exceptions.HTTPError as err:
        print(f"\n[ERROR] HTTP Error: {err.response.status_code} - {err.response.text}")
        sys.exit(4)
    except requests.exceptions.ConnectionError as err:
        print(f"\n[ERROR] Connection error: {type(err).__name__}: {err}")
        sys.exit(5)
    except requests.exceptions.RequestException as err:
        print(f"\n[ERROR] Request failed: {type(err).__name__}: {err}")
        sys.exit(5)
    except Exception as err:
        import traceback
        print(f"\n[ERROR] Unexpected error: {err}")
        traceback.print_exc()
        sys.exit(6)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
 
