from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any, List, Tuple
import fitz  # PyMuPDF
import re
from datetime import datetime
from difflib import SequenceMatcher

from PIL import Image, ImageOps, ImageFilter
import pytesseract
import io
import os

# Tesseract binary. Defaults to the standard Windows install path; set the
# TESSERACT_CMD environment variable to override it (e.g. on a non-Windows host).
pytesseract.pytesseract.tesseract_cmd = (
    os.environ.get("TESSERACT_CMD") or r"C:\Program Files\Tesseract-OCR\tesseract.exe"
)

app = FastAPI()

# Template-OCR debugging: when enabled, ocr_boxes_from_pdf writes every
# preprocessed crop (PNG) and its OCR text into a per-invoice folder under
# DEBUG_CROPS_DIR. Off by default; enable with SAVE_DEBUG_CROPS=1/true/yes/on and
# relocate the folder with DEBUG_CROPS_DIR.
SAVE_DEBUG_CROPS = os.environ.get("SAVE_DEBUG_CROPS", "").strip().lower() in {"1", "true", "yes", "on"}
DEBUG_CROPS_DIR = (
    os.environ.get("DEBUG_CROPS_DIR")
    or r"C:\xampp\htdocs\portal.hickeyplanthire.co.uk\CRUD\suppliers\supplier_invoices\template_previews\debug_crops"
)


# ---------- request models ----------
class TemplateModel(BaseModel):
    # Per-supplier OCR template: rectangles on one page that are OCR'd individually.
    id_supplier: int  # supplier the template belongs to; extract() only applies it when this equals the resolved supplier
    page_no: int = 1  # 1-based page the boxes refer to (clamped to the document's page range when used)
    dpi: int = 300  # render resolution used for the box crops
    boxes: Dict[str, Any]  # {field: {x,y,w,h,psm}} in PDF points (72dpi)


class Req(BaseModel):
    # Request body for POST /extract (sent by the PHP portal).
    id_supplier_invoice: int  # invoice id; passed to ocr_boxes_from_pdf as invoice_id (debug-crop folder name)
    pdf_path: str  # filesystem path of the PDF; opened as-is with fitz.open
    suppliers: list  # candidate supplier records, read as dicts by best_supplier_match
    template: Optional[TemplateModel] = None  # OCR template; applied only when it belongs to the resolved supplier

    # The forced supplier id is accepted under several key names because PHP may
    # send any of them, so none is silently ignored. get_forced_supplier() returns
    # the first positive value in this order:
    #   forced_supplier_id, force_supplier_id, force_supplier
    forced_supplier_id: Optional[int] = None
    force_supplier_id: Optional[int] = None
    force_supplier: Optional[int] = None


def get_forced_supplier(req: Req) -> Optional[int]:
    """
    Return the supplier id PHP asked us to force, or None.

    PHP may send the id under any of three keys. They are checked in this order
    and the first one holding a positive integer wins:
      forced_supplier_id, force_supplier_id, force_supplier
    Values that are None, non-positive, or not convertible to int are skipped.
    """
    candidates = (req.forced_supplier_id, req.force_supplier_id, req.force_supplier)
    for raw in candidates:
        if raw is None:
            continue
        try:
            as_int = int(raw)
        except Exception:
            continue
        if as_int > 0:
            return as_int
    return None


# ---------- helpers ----------
def norm_name(s: str) -> str:
    """
    Normalise a company name for comparison.

    Upper-cases the name, turns the punctuation . , ( ) [ ] { } into spaces,
    removes the legal-form words LIMITED, LTD, PLC, LLP, INC, CO and COMPANY,
    then collapses whitespace runs to single spaces and trims the ends.
    Falsy input (None, "") gives "".
    """
    if not s:
        return ""
    rules = (
        (r"[\.\,\(\)\[\]\{\}]", " "),
        (r"\b(LIMITED|LTD|PLC|LLP|INC|CO|COMPANY)\b", " "),
        (r"\s+", " "),
    )
    result = s.upper()
    for pattern, replacement in rules:
        result = re.sub(pattern, replacement, result)
    return result.strip()


def similarity(a: str, b: str) -> float:
    """Return how similar *a* and *b* are as a percentage (0.0-100.0), using difflib's SequenceMatcher.ratio()."""
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return 100.0 * matcher.ratio()


def find_vat_no(text: str) -> str | None:
    m = re.search(r"\bVAT\s*(?:No|Number)?\s*[:#]?\s*(GB)?\s*([0-9]{9})\b", text, re.IGNORECASE)
    if m:
        return ("GB" + m.group(2)).upper()
    return None


def find_postcode(text: str) -> str | None:
    t = text.upper()
    m = re.search(r"\b([A-Z]{1,2}\d{1,2}[A-Z]?)\s*(\d[A-Z]{2})\b", t)
    if m:
        return (m.group(1) + m.group(2)).replace(" ", "")
    return None


def find_invoice_number(text: str) -> str | None:
    patterns = [
        r"\bInvoice\s*(?:No|Number|#)\s*[:\-]?\s*([A-Z0-9][A-Z0-9\-\/]{2,24})\b",
        r"\bInv\s*(?:No|#)\s*[:\-]?\s*([A-Z0-9][A-Z0-9\-\/]{2,24})\b",
    ]
    for p in patterns:
        m = re.search(p, text, re.IGNORECASE)
        if m:
            candidate = m.group(1).strip()
            if not re.search(r"\d", candidate):
                continue
            if candidate.upper() in {"DATE", "INVOICE", "NUMBER"}:
                continue
            return candidate
    return None


def parse_date_any(text: str) -> str | None:
    """
    Supports:
      - 10/12/2025, 10-12-2025, 10.12.2025 (with optional spaces)
      - 31 Jan 2026
      - 31-Jan-26 / 31-Jan-2026
      - 10122025 (ddmmyyyy)
    Returns YYYY-MM-DD (UK day-first preference)
    """
    if not text:
        return None

    t = text.strip()
    t = t.replace("—", "-").replace("–", "-")
    t = re.sub(r"[\,\.;:\)\]]+$", "", t)
    t = re.sub(r"\s*-\s*", "-", t)
    t = re.sub(r"\s+", " ", t)

    # 0) Compact 8-digit numeric: ddmmyyyy
    m = re.search(r"\b(\d{2})(\d{2})(\d{4})\b", t)
    if m:
        d, mo, y = int(m.group(1)), int(m.group(2)), int(m.group(3))
        try:
            dt = datetime(y, mo, d)
            if 2015 <= dt.year <= 2035:
                return dt.strftime("%Y-%m-%d")
        except:
            pass

    # 1) Numeric: dd/mm/yyyy or dd-mm-yy etc (UK day first)
    m = re.search(r"\b(\d{1,2})\s*[\/\-\.\,]\s*(\d{1,2})\s*[\/\-\.\,]\s*(\d{2,4})\b", t)
    if m:
        d, mo, y = int(m.group(1)), int(m.group(2)), int(m.group(3))
        if y < 100:
            y += 2000
        try:
            dt = datetime(y, mo, d)
            if 2015 <= dt.year <= 2035:
                return dt.strftime("%Y-%m-%d")
        except:
            pass

    # 2) Text month with spaces
    m = re.search(
        r"\b(\d{1,2})\s+"
        r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Sept|Oct|Nov|Dec)"
        r"[a-z]*\s+(\d{2,4})\b",
        t,
        re.IGNORECASE
    )
    if m:
        d = int(m.group(1))
        mon = m.group(2)[:3].title()
        y = int(m.group(3))
        if y < 100:
            y += 2000
        try:
            dt = datetime.strptime(f"{d} {mon} {y}", "%d %b %Y")
            if 2015 <= dt.year <= 2035:
                return dt.strftime("%Y-%m-%d")
        except:
            pass

    # 3) Text month with hyphens
    m = re.search(
        r"\b(\d{1,2})\-"
        r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Sept|Oct|Nov|Dec)"
        r"[a-z]*\-(\d{2,4})\b",
        t,
        re.IGNORECASE
    )
    if m:
        d = int(m.group(1))
        mon = m.group(2)[:3].title()
        y = int(m.group(3))
        if y < 100:
            y += 2000
        try:
            dt = datetime.strptime(f"{d} {mon} {y}", "%d %b %Y")
            if 2015 <= dt.year <= 2035:
                return dt.strftime("%Y-%m-%d")
        except:
            pass

    return None


def find_invoice_date(text: str) -> str | None:
    """
    Find the invoice date in OCR text and return it as 'YYYY-MM-DD'.

    Labels are tried in priority order: "Invoice Date", then "Tax Point Date",
    then any word starting with "Date". Each label is searched across the first
    120 lines. For every line carrying the label (top to bottom), the date is
    parsed from that line, or failing that from the line below it. This covers
    layouts where the value sits under a bare label. The first success wins.

    If no labelled date parses, the first parseable date in the first 160 lines
    is returned (or None).
    """
    if not text:
        return None

    # Priority order matters: a generic "Date" (e.g. "Due Date") must not beat
    # an explicit "Invoice Date" just because it appears earlier in the text.
    label_patterns = (
        re.compile(r"Invoice\s*Date", re.IGNORECASE),
        re.compile(r"Tax\s*Point\s*Date", re.IGNORECASE),
        re.compile(r"\bDate", re.IGNORECASE),
    )
    lines = text.splitlines()
    head = lines[:120]

    for label in label_patterns:
        for i, line in enumerate(head):
            if not label.search(line):
                continue
            d = parse_date_any(line)
            if d:
                return d
            if i + 1 < len(lines):
                d = parse_date_any(lines[i + 1])
                if d:
                    return d

    return parse_date_any("\n".join(lines[:160]))


def money_to_decimal(s: str) -> float | None:
    if not s:
        return None
    s = s.replace(",", "").replace("£", "").replace("GBP", "").strip()
    try:
        return float(s)
    except:
        return None


def clean_money_from_text(s: str) -> float | None:
    if not s:
        return None
    t = s.upper().replace("GBP", "").replace("£", "").replace(",", "").strip()
    t = t.replace("O", "0")
    m = re.search(r"\b\d{1,9}(?:\.\d{1,2})?\b", t)
    if not m:
        return None
    try:
        return float(m.group(0))
    except:
        return None


def find_totals(text: str) -> Tuple[float | None, float | None, float | None, str | None]:
    """
    Heuristically extract (net, vat, gross, currency) from OCR text.

    Up to 180 non-blank lines are scanned from the bottom upwards. Each line that
    carries a 2-decimal amount is classified by its label, and feeds at most one
    field:

      * "inc/incl/including/inclusive (of) VAT" or "VAT inc..."     -> gross
      * "ex/exc/excl/excluding/exclusive (of) VAT", "VAT ex...",
        "Net", "Subtotal", "Sub Total"                               -> net
      * any other line mentioning VAT, except VAT-number labels
        ("VAT No", "VAT Number", "VAT Reg No")                       -> vat
      * "Total", "Grand Total", "Total/Amount/Balance Due"           -> gross

    Each field takes the right-most amount on the bottom-most line classified
    for it.

    A missing VAT figure is derived as gross - net when that is non-negative.
    Otherwise, a missing net figure is derived as gross - vat when that is
    non-negative.

    Returns:
        (net, vat, gross, currency). Amounts are floats or None; currency is
        "GBP" when the text contains "£" or "GBP", else None.
    """
    net = vat = gross = None
    if not text:
        return net, vat, gross, None

    currency = "GBP" if ("£" in text or "GBP" in text.upper()) else None

    # A 2-decimal amount, optionally "£"-prefixed. Thousands commas are optional,
    # so "1234.56" is read whole; the look-arounds stop a match from starting or
    # ending inside a longer number.
    amount_re = re.compile(r"(?<![\d.])£?\s*(?:\d{1,3}(?:,\d{3})+|\d+)\.\d{2}(?!\d)")
    inc_vat_re = re.compile(
        r"\b(?:inc|incl|including|inclusive)\b\.?\s*(?:of\s*)?VAT\b"
        r"|\bVAT\s*(?:inc|incl|included|including|inclusive)\b",
        re.IGNORECASE,
    )
    ex_vat_re = re.compile(
        r"\b(?:ex|exc|excl|excluding|exclusive)\b\.?\s*(?:of\s*)?VAT\b"
        r"|\bVAT\s*(?:ex|exc|excl|excluded|excluding|exclusive)\b",
        re.IGNORECASE,
    )
    net_re = re.compile(r"\b(Net|Subtotal|Sub\s*Total)\b", re.IGNORECASE)
    vat_re = re.compile(r"\bVAT\b", re.IGNORECASE)
    vat_id_re = re.compile(r"VAT\s*(?:Reg(?:istration)?\.?\s*)?(?:No|Number)", re.IGNORECASE)
    gross_re = re.compile(
        r"\b(Total\s*Due|Amount\s*Due|Balance\s*Due|Grand\s*Total|Total)\b", re.IGNORECASE
    )

    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    # Walk upwards from the bottom (totals are expected near the foot of the
    # invoice), so the bottom-most labelled line wins for each field.
    for line in lines[::-1][:180]:
        amounts = amount_re.findall(line)
        if not amounts:
            continue

        if inc_vat_re.search(line):
            field = "gross"
        elif ex_vat_re.search(line) or net_re.search(line):
            field = "net"
        elif vat_re.search(line) and not vat_id_re.search(line):
            field = "vat"
        elif gross_re.search(line):
            field = "gross"
        else:
            continue

        value = money_to_decimal(amounts[-1])
        if field == "gross" and gross is None:
            gross = value
        elif field == "net" and net is None:
            net = value
        elif field == "vat" and vat is None:
            vat = value

        if net is not None and vat is not None and gross is not None:
            break

    if vat is None and gross is not None and net is not None:
        derived = round(gross - net, 2)
        if derived >= 0:
            vat = derived
    elif net is None and gross is not None and vat is not None:
        derived = round(gross - vat, 2)
        if derived >= 0:
            net = derived

    return net, vat, gross, currency


def guess_supplier_from_ocr_text(text: str) -> str | None:
    lines = [l.strip() for l in text.splitlines() if l.strip()]
    for line in lines[:30]:
        if re.search(r"\b(INVOICE|DATE|TAX\s+INVOICE|VAT\s+INVOICE|CREDIT\s+NOTE)\b", line, re.IGNORECASE):
            continue
        if 3 <= len(line) <= 80:
            return line[:255]
    return None


def best_supplier_match(suppliers, raw_name, vat_no, postcode):
    """
    Match the OCR'd supplier against the candidate list sent by PHP.

    Args:
        suppliers: iterable of dicts. "id_supplier" is required. These keys are
            optional: "supplier_vat_no", "supplier_postcode_norm",
            "supplier_name_norm", "supplier_name", and "supplier_aliases"
            (a list or tuple of names; anything else is ignored).
        raw_name: supplier name guessed from the OCR text (may be None).
        vat_no: VAT number found in the text (may be None).
        postcode: postcode found in the text (may be None).

    Returns:
        (id_supplier, method, score) for the first rule that matches. Rules are
        tried in this order:
          * "vat" 99.0 — VAT numbers compared on letters/digits, ignoring a
            leading "GB";
          * "postcode" 92.0 — spaces ignored, case-insensitive;
          * "name_norm" 90.0 — exact match after norm_name() on both sides;
          * "alias" 88.0 — exact alias match after norm_name();
          * "fuzzy" — best similarity(), if it is >= 88.
        Returns (None, None, 0.0) when nothing matches.
    """
    def vat_key(value) -> str:
        # "GB 123 4567 89", "GB123456789" and "123456789" all compare equal.
        key = re.sub(r"[^A-Z0-9]", "", str(value or "").upper())
        return key[2:] if key.startswith("GB") else key

    def postcode_key(value) -> str:
        return str(value or "").replace(" ", "").upper()

    def name_key(s) -> str:
        # Prefer PHP's pre-normalised name but still run it through norm_name so
        # both sides follow the same normalisation rules.
        return norm_name(str(s.get("supplier_name_norm") or s.get("supplier_name") or ""))

    raw_norm = norm_name(raw_name or "")

    v = vat_key(vat_no)
    if v:
        for s in suppliers:
            if vat_key(s.get("supplier_vat_no")) == v:
                return s["id_supplier"], "vat", 99.0

    pc = postcode_key(postcode)
    if pc:
        for s in suppliers:
            if postcode_key(s.get("supplier_postcode_norm")) == pc:
                return s["id_supplier"], "postcode", 92.0

    if raw_norm:
        for s in suppliers:
            if name_key(s) == raw_norm:
                return s["id_supplier"], "name_norm", 90.0

        for s in suppliers:
            aliases = s.get("supplier_aliases") or []
            if not isinstance(aliases, (list, tuple)):
                aliases = []
            if any(a and norm_name(str(a)) == raw_norm for a in aliases):
                return s["id_supplier"], "alias", 88.0

        best_id, best_score = None, 0.0
        for s in suppliers:
            cand = name_key(s)
            if not cand:
                continue
            score = similarity(raw_norm, cand)
            if score > best_score:
                best_id, best_score = s["id_supplier"], score
        if best_id is not None and best_score >= 88.0:
            return best_id, "fuzzy", best_score

    return None, None, 0.0


def ocr_text_from_pdf(doc: fitz.Document, max_pages: int = 2, dpi: int = 300) -> str:
    """
    OCR the first *max_pages* pages of *doc* and return the combined text.

    Each page is rendered at *dpi* (scaled from PDF's 72 points per inch) to a
    PNG without alpha, then read by Tesseract with lang="eng" and "--psm 6".
    Page texts are joined with newlines and the result is stripped.
    """
    scale = dpi / 72
    zoom = fitz.Matrix(scale, scale)
    page_count = min(len(doc), max_pages)

    def _ocr_page(index: int) -> str:
        png_bytes = doc[index].get_pixmap(matrix=zoom, alpha=False).tobytes("png")
        page_image = Image.open(io.BytesIO(png_bytes))
        return pytesseract.image_to_string(page_image, lang="eng", config="--psm 6")

    return "\n".join(_ocr_page(i) for i in range(page_count)).strip()


def preprocess_crop(img: Image.Image) -> Image.Image:
    """Return a grayscale, auto-contrasted, sharpened version of *img* for OCR."""
    return ImageOps.autocontrast(ImageOps.grayscale(img)).filter(ImageFilter.SHARPEN)


def ocr_boxes_from_pdf(
    doc: fitz.Document,
    page_no_1based: int,
    boxes: dict,
    dpi: int = 300,
    invoice_id: int = 0
) -> dict:
    """
    OCR each template box on one page of *doc* and return {field: text}.

    Args:
        doc: open PyMuPDF document.
        page_no_1based: page to read (1 = first). Values outside the document
            are clamped to the first/last page.
        boxes: {field: {"x", "y", "w", "h", optional "psm"}} in PDF points
            (72 dpi). Boxes with missing or non-numeric coordinates, or a
            non-positive width/height, are skipped. A missing, non-integer or
            out-of-range (0-13) "psm" falls back to 6.
        dpi: render resolution for the crops.
        invoice_id: only used to name the debug-crop folder.

    Returns:
        Mapping of field name -> stripped OCR text for every box that was read.
        Empty when the document has no pages.

    When SAVE_DEBUG_CROPS is on, each preprocessed crop (PNG) and its text are
    also written to a per-invoice folder under DEBUG_CROPS_DIR.
    """
    out = {}
    if len(doc) == 0:
        return out

    # Clamp the requested page into the document's range.
    page_index = min(max(0, page_no_1based - 1), len(doc) - 1)
    page = doc[page_index]

    # Box coordinates are PDF points (72 per inch); scale rendering to *dpi*.
    mat = fitz.Matrix(dpi / 72, dpi / 72)

    debug_dir = None
    if SAVE_DEBUG_CROPS:
        debug_dir = os.path.join(DEBUG_CROPS_DIR, f"inv_{invoice_id}_p{page_no_1based}_dpi{dpi}")
        os.makedirs(debug_dir, exist_ok=True)

    for key, b in boxes.items():
        try:
            x, y, w, h = float(b["x"]), float(b["y"]), float(b["w"]), float(b["h"])
        except (KeyError, TypeError, ValueError):
            continue
        if w <= 0 or h <= 0:
            continue  # degenerate box: nothing to render

        # A bad psm for one box must not abort the whole template.
        try:
            psm = int(b.get("psm", 6))
        except (TypeError, ValueError):
            psm = 6
        if not 0 <= psm <= 13:
            psm = 6

        rect = fitz.Rect(x, y, x + w, y + h)
        pix = page.get_pixmap(matrix=mat, clip=rect, alpha=False)
        img = Image.open(io.BytesIO(pix.tobytes("png")))
        img2 = preprocess_crop(img)

        # Field names come from the request: sanitise before using them as file
        # names so they cannot escape debug_dir.
        stem = re.sub(r"[^A-Za-z0-9_-]", "_", str(key)) or "box"

        if debug_dir:
            img2.save(os.path.join(debug_dir, f"{stem}.png"))

        txt = pytesseract.image_to_string(img2, lang="eng", config=f"--psm {psm}")
        out[key] = txt.strip()

        if debug_dir:
            with open(os.path.join(debug_dir, f"{stem}.txt"), "w", encoding="utf-8") as f:
                f.write(out[key])

    return out


# ---------- LINE ITEMS PARSER ----------
def parse_line_items(raw: str) -> List[Dict[str, Any]]:
    """
    Parse OCR text from a template's line-items area into item dicts.

    Each non-blank line becomes one item, unless it contains a table-header or
    summary word (description, details, qty, quantity, unit, price, total,
    amount, vat), in which case it is skipped. Per line:

      * money: numbers with exactly two decimals (thousands commas ignored).
        The last one is ``line_total`` and the one before it, if any, is
        ``unit_price``.
      * qty: the last number left after money is removed.
      * product_uom / product_uom_qty: a number directly followed by a known
        unit word (e.g. "3 DAYS" -> 3.0, "day"). Failing that, the first known
        unit word alone sets product_uom.
      * product_code: the first token of 3+ characters mixing letters and digits.
      * description: what remains after removing money and the product code.
        Lines with nothing left are dropped.

    The OCR fix "O" -> "0" is applied only to tokens that are otherwise numeric
    (e.g. "1O.OO"). Words such as "HOSE" or "TOOL" keep their spelling and are
    not mistaken for product codes. At most 200 items are returned.
    """
    if not raw:
        return []

    lines = [l.strip() for l in raw.splitlines() if l.strip()]
    items = []
    line_no = 1

    money_pat = re.compile(r"\b\d{1,9}(?:\.\d{2})\b")
    qty_pat = re.compile(r"\b\d{1,6}(?:\.\d{1,3})?\b")
    header_pat = re.compile(
        r"\b(description|details|qty|quantity|unit|price|total|amount|vat)\b", re.IGNORECASE
    )
    qty_uom_pat = re.compile(r"\b(\d{1,6}(?:\.\d{1,3})?)\s*([A-Za-z]{1,8})\b")

    UOM_MAP = {
        "EA": "each", "EACH": "each", "UNIT": "each", "UNITS": "each",
        "BOX": "box", "BX": "box",
        "PACK": "pack", "PK": "pack",
        "SET": "set",
        "DAY": "day", "DAYS": "day",
        "WEEK": "week", "WKS": "week", "WEEKS": "week",
        "MONTH": "month", "MNTH": "month", "MONTHS": "month",
        "HR": "hour", "HRS": "hour", "HOUR": "hour", "HOURS": "hour",
        "M": "m", "METRE": "m", "METER": "m", "METRES": "m",
        "L": "l", "LTR": "l", "LITRE": "l", "LITRES": "l",
        "KG": "kg", "KGS": "kg",
        "TON": "ton", "TONNE": "ton", "TONNES": "ton",
        "JOB": "job", "ITEM": "item"
    }

    def norm_uom(token: str) -> Optional[str]:
        t = re.sub(r"[^A-Z]", "", token.upper())
        if not t:
            return None
        return UOM_MAP.get(t)

    def ocr_zero_fix(token: str) -> str:
        # Read letter "O" as digit 0 only when the token is otherwise numeric
        # (has a digit and no other letters), e.g. "1O.OO" -> "10.00".
        if "O" in token and re.search(r"\d", token) and not re.search(r"[A-Za-z]", token.replace("O", "")):
            return token.replace("O", "0")
        return token

    for ln in lines:
        if header_pat.search(ln):
            continue

        s = re.sub(r"\s{2,}", " ", ln).strip()
        s_money_norm = " ".join(ocr_zero_fix(tok) for tok in s.replace(",", "").split(" "))

        monies = money_pat.findall(s_money_norm)
        unit_price = line_total = None

        if len(monies) >= 2:
            unit_price = float(monies[-2])
            line_total = float(monies[-1])
        elif len(monies) == 1:
            line_total = float(monies[-1])

        s_wo_money = money_pat.sub("", s_money_norm).strip()

        # Matches of qty_pat are always valid float literals.
        qty_candidates = qty_pat.findall(s_wo_money)
        qty = float(qty_candidates[-1]) if qty_candidates else None

        product_uom = None
        product_uom_qty = None

        m_uom = qty_uom_pat.search(s_wo_money)
        if m_uom:
            u = norm_uom(m_uom.group(2))
            if u:
                product_uom_qty = float(m_uom.group(1))
                product_uom = u

        if not product_uom:
            for tok in re.split(r"\s+", s_wo_money):
                u = norm_uom(tok)
                if u:
                    product_uom = u
                    break

        product_code = None
        for tok in re.split(r"\s+", s_wo_money):
            t2 = re.sub(r"[^A-Za-z0-9\-\/]", "", tok)
            if len(t2) >= 3 and re.search(r"[A-Za-z]", t2) and re.search(r"\d", t2):
                product_code = t2[:80]
                break

        desc = s_wo_money
        if product_code:
            desc = desc.replace(product_code, " ").strip()
        desc = re.sub(r"\s{2,}", " ", desc).strip()

        if not desc:
            continue

        items.append({
            "line_no": line_no,
            "description": desc[:1000],
            "product_code": product_code,
            "qty": qty,
            "product_uom_qty": product_uom_qty,
            "product_uom": product_uom,
            "unit_price": unit_price,
            "line_total": line_total,
            "is_stock_item": 0
        })

        line_no += 1
        if line_no > 200:
            break

    return items


# ---------- endpoint ----------
@app.post("/extract")
def extract(req: Req):
    """
    OCR the invoice PDF at ``req.pdf_path`` and return the extracted fields.

    Steps:
      1. OCR the first two pages (300 dpi) and run the text heuristics for the
         supplier name, VAT number, postcode, invoice number/date and totals.
      2. Match the supplier against ``req.suppliers``. A forced supplier id sent
         by PHP overrides the match (method "forced", score 100).
      3. If ``req.template`` belongs to the resolved supplier and has boxes, OCR
         each box and let non-empty box values override the heuristic ones.
         Template results are applied all-or-nothing: if reading or parsing the
         boxes raises, the error is logged and only page-level results are
         returned (template_used = False, empty box_results and items).
      4. Compute a heuristic confidence; below 0.75 the invoice is flagged for
         review.

    NOTE(review): ``pdf_path`` is opened exactly as supplied. This endpoint
    assumes a trusted caller and does not restrict the path to any directory.
    """
    doc = fitz.open(req.pdf_path)
    try:
        forced = get_forced_supplier(req)

        # Baseline OCR: raw_text for the response plus fallback values for every field.
        text = ocr_text_from_pdf(doc, max_pages=2, dpi=300)
        # NOTE(review): reported whenever the template is not used, even when no
        # supplier was forced. Left unchanged in case the PHP side depends on this
        # value; confirm before renaming.
        method = "ocr_forced"

        supplier_raw = guess_supplier_from_ocr_text(text)
        vat_no_found = find_vat_no(text)
        postcode_found = find_postcode(text)
        inv_no = find_invoice_number(text)
        inv_date = find_invoice_date(text)
        net, vat, gross, currency = find_totals(text)

        # Supplier match (may be overridden below).
        id_supplier, match_method, match_score = best_supplier_match(
            req.suppliers, supplier_raw, vat_no_found, postcode_found
        )

        # FORCE supplier (from dropdown).
        if forced:
            id_supplier = int(forced)
            match_method = "forced"
            match_score = 100.0

        template_used = False
        box_results: Dict[str, str] = {}

        purchase_order_no = None
        contract_no = None
        items: List[Dict[str, Any]] = []

        # Use template boxes only when the template belongs to the resolved supplier.
        tpl = req.template
        if tpl and id_supplier and tpl.id_supplier == id_supplier and tpl.boxes:
            try:
                boxes_txt = ocr_boxes_from_pdf(
                    doc=doc,
                    page_no_1based=tpl.page_no,
                    boxes=tpl.boxes,
                    dpi=tpl.dpi,
                    invoice_id=req.id_supplier_invoice
                )

                # Parse every box before touching any result, so a failure part-way
                # through cannot leave a mix of template and page-level values.
                inv_no_box = (boxes_txt.get("invoice_number") or "").strip()
                inv_date_box = parse_date_any(boxes_txt.get("invoice_date", ""))
                net_box = clean_money_from_text(boxes_txt.get("net_total", ""))
                vat_box = clean_money_from_text(boxes_txt.get("vat_total", ""))
                gross_box = clean_money_from_text(boxes_txt.get("gross_total", ""))
                vat_no_box = (boxes_txt.get("vat_no") or "").strip()
                postcode_box = (boxes_txt.get("postcode") or "").strip()
                po_box = (boxes_txt.get("purchase_order_no") or "").strip() or None
                contract_box = (boxes_txt.get("contract_no") or "").strip() or None
                raw_items = (boxes_txt.get("line_items_area") or "").strip()
                items_box = parse_line_items(raw_items) if raw_items else []

                # Normalise box-read identifiers with the same helpers used for the
                # page-level text; keep the raw box text when it does not parse.
                if vat_no_box:
                    vat_no_box = find_vat_no(vat_no_box) or vat_no_box
                if postcode_box:
                    postcode_box = find_postcode(postcode_box) or postcode_box
            except Exception:
                import logging
                logging.getLogger(__name__).exception(
                    "Template OCR failed for supplier invoice %s; using page-level OCR only",
                    req.id_supplier_invoice,
                )
            else:
                template_used = True
                box_results = boxes_txt

                if inv_no_box:
                    inv_no = inv_no_box
                if inv_date_box:
                    inv_date = inv_date_box
                if net_box is not None:
                    net = net_box
                if vat_box is not None:
                    vat = vat_box
                if gross_box is not None:
                    gross = gross_box
                if vat_no_box:
                    vat_no_found = vat_no_box
                if postcode_box:
                    postcode_found = postcode_box

                purchase_order_no = po_box
                contract_no = contract_box
                items = items_box

        # Confidence heuristic (capped just below 1.0).
        conf = 0.0
        conf += 0.25 if supplier_raw else 0.0
        conf += 0.35 if inv_no else 0.0
        conf += 0.30 if inv_date else 0.0
        conf += 0.10 if gross is not None else 0.0
        if id_supplier:
            conf += 0.10
        if purchase_order_no or contract_no:
            conf += 0.05
        if items:
            conf += 0.05
        conf = min(conf, 0.999)

        needs_review = conf < 0.75

        return {
            "extraction_method": "ocr_template" if template_used else method,
            "raw_text": text,

            "supplier_name_raw": supplier_raw,
            "vat_no": vat_no_found,
            "postcode": postcode_found,

            "invoice_number": inv_no,
            "invoice_date": inv_date,
            "net_total": net,
            "vat_total": vat,
            "gross_total": gross,
            "currency": currency or "GBP",

            "id_supplier": id_supplier,
            "supplier_match_method": match_method,
            "supplier_match_score": match_score,

            "purchase_order_no": purchase_order_no,
            "contract_no": contract_no,

            "confidence": round(conf, 3),
            "needs_review": 1 if needs_review else 0,

            "template_used": template_used,
            "box_results": box_results,

            "items": items
        }
    finally:
        doc.close()
