from collections import defaultdict
from datetime import datetime
import math
import re
import unicodedata

import pytz

from model.product_catalog_model import ProductCatalogModel
from model.user_model import UserModel

from MongoDBConnection import db

# MongoDB collection handles shared by every ImportModel method.
orders_collection, products_collection = db["orders"], db["products"]

# Timezone used when stamping generated order IDs (see ImportModel._generate_order_id).
VN_TZ = pytz.timezone("Asia/Ho_Chi_Minh")


class ImportModel:
    """Import spreadsheet rows as orders in the ``orders`` collection.

    Pipeline (see ``process_import``): basic field validation, enrichment from
    the ``products`` collection (variant, price, shipping rate, import tax),
    grouping by customer name + address, then one insert or update per group.
    """

    # Collapses any run of whitespace (spaces, tabs, newlines) to one space.
    _WHITESPACE_RE = re.compile(r"\s+")

    # Extracts N from an "items" cell such as "T-Shirt (Qty: 3)".
    _QTY_RE = re.compile(r"\(Qty:\s*(\d+)\)", re.IGNORECASE)

    # Country spellings recognised as the United States. They are compared
    # after lower-casing, stripping diacritics and removing dots, so
    # "U.S.A.", "Mỹ", "Hoa Kỳ" and "Nước Mỹ" all match.
    # NOTE(review): "my" is Vietnamese for the USA but is also the ISO 3166
    # alpha-2 code for Malaysia — confirm this collision is acceptable.
    _US_COUNTRY_VALUES = frozenset({
        "us",
        "usa",
        "united states",
        "united states of america",
        "my",
        "hoa ky",
        "nuoc my",
    })

    # Rate-region labels accepted as the US rate when no shipping rate is
    # labelled exactly "US".
    _US_RATE_REGIONS = frozenset({
        "us",
        "usa",
        "united states",
        "united states of america",
    })

    @staticmethod
    def _parse_address(address: str):
        """Split a comma-separated address into its named parts.

        The expected order is "street, city, state, zip, country". Empty
        segments are dropped, missing trailing parts become "", and any
        parts beyond the fifth are ignored.
        """
        parts = [part.strip() for part in str(address or "").split(",") if part.strip()]
        street, city, state, zip_code, country = (parts + [""] * 5)[:5]

        return {
            "street": street,
            "city": city,
            "state": state,
            "country": country,
            "zipCode": zip_code,
        }

    @staticmethod
    def _collapse_whitespace(value) -> str:
        """Stringify *value* (falsy -> ""), collapse whitespace runs and trim."""
        return ImportModel._WHITESPACE_RE.sub(" ", str(value or "")).strip()

    @staticmethod
    def _normalize_compare_text(value) -> str:
        """Whitespace-normalised, UPPER-cased text used to build merge keys."""
        return ImportModel._collapse_whitespace(value).upper()

    @staticmethod
    def _normalize_variant_text(value) -> str:
        """Whitespace-normalised, lower-cased text used for variant/region matching."""
        return ImportModel._collapse_whitespace(value).lower()

    @staticmethod
    def _build_group_key(row: dict) -> str:
        """Return the merge key "NAME||STREET||CITY||STATE||COUNTRY||ZIP" for a row.

        Rows with the same key (case/whitespace-insensitive) become one order.
        """
        normalize = ImportModel._normalize_compare_text
        address = ImportModel._parse_address(row.get("address", ""))

        return "||".join([
            normalize(row.get("customer", "")),
            normalize(address.get("street", "")),
            normalize(address.get("city", "")),
            normalize(address.get("state", "")),
            normalize(address.get("country", "")),
            normalize(address.get("zipCode", "")),
        ])

    @staticmethod
    def _to_float(value, default=0):
        """Best-effort float parse of a spreadsheet cell.

        Thousands separators (",") are removed. ``None``, "", unparsable
        values and non-finite results ("nan", "inf", NaN cells) yield
        ``float(default)``.
        """
        try:
            if value is None or value == "":
                return float(default)
            result = float(str(value).replace(",", "").strip())
        except (TypeError, ValueError, OverflowError):
            return float(default)
        # NaN/inf parse successfully but would poison order totals.
        return result if math.isfinite(result) else float(default)

    @staticmethod
    def _to_int(value, default=1):
        """Best-effort int parse of a spreadsheet cell (truncates "2.9" to 2).

        Thousands separators (",") are removed. ``None``, "", unparsable and
        non-finite values yield ``int(default)``.
        """
        try:
            if value is None or value == "":
                return int(default)
            return int(float(str(value).replace(",", "").strip()))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")).
            return int(default)

    @staticmethod
    def _build_uploader_info(id_khach_hang: str = None) -> dict:
        """Look up the uploading user and return their display name and email.

        Both fields are "" when no id is given or no user is found. The name
        falls back to the email when first/last name are both empty.
        """
        if id_khach_hang:
            user = UserModel.get_user_by_id(id_khach_hang)
            if user:
                email = user.get("email", "")
                full_name = f"{user.get('firstName', '')} {user.get('lastName', '')}".strip()
                return {
                    "uploadedByName": full_name or email,
                    "uploadedByEmail": email,
                }

        return {
            "uploadedByName": "",
            "uploadedByEmail": "",
        }

    @staticmethod
    def _validate_basic_row(row: dict) -> list:
        """Return user-facing (Vietnamese) messages for missing required fields."""
        errors = []

        if not str(row.get("customer") or "").strip():
            errors.append("Thiếu tên khách hàng")

        if not str(row.get("address") or "").strip():
            errors.append("Thiếu địa chỉ")

        if not str(row.get("productId") or "").strip():
            errors.append("Thiếu mã sản phẩm")

        if ImportModel._to_int(row.get("quantity"), 0) <= 0:
            errors.append("Số lượng phải lớn hơn 0")

        return errors

    @staticmethod
    def _get_product_display_name(product_doc: dict, row: dict, product_catalog: dict) -> str:
        """Pick a display name: product title, row's productName, catalog name, then the id."""
        product_id = str(row.get("productId") or "").strip()
        return str(
            (product_doc or {}).get("title")
            or row.get("productName")
            or product_catalog.get(product_id)
            or product_id
        ).strip()

    @staticmethod
    def _is_us_country(country_value: str) -> bool:
        """Return True when *country_value* names the United States.

        Matching ignores case, extra whitespace, dots and diacritics, so
        "U.S.A.", "usa", "Mỹ" and "Hoa Kỳ" are all accepted.
        """
        normalized = ImportModel._normalize_variant_text(country_value)
        # NFKD splits accented letters into base letter + combining marks,
        # which are then dropped ("kỳ" -> "ky", "nước" -> "nuoc").
        decomposed = unicodedata.normalize("NFKD", normalized)
        plain = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
        return plain.replace(".", "") in ImportModel._US_COUNTRY_VALUES

    @staticmethod
    def _detect_shipping_region(row: dict) -> str:
        """Return "US" when the row's address country is the USA, else "International"."""
        country = ImportModel._parse_address(row.get("address", "")).get("country", "")
        return "US" if ImportModel._is_us_country(country) else "International"

    @staticmethod
    def _normalize_shipping_rate(rate: dict) -> dict:
        """Coerce a raw shipping-rate entry into trimmed strings and floats."""
        return {
            "region": str(rate.get("region") or "").strip(),
            "transitTime": str(rate.get("transitTime") or "").strip(),
            "cost": ImportModel._to_float(rate.get("cost"), 0),
            "importTax": ImportModel._to_float(rate.get("importTax"), 0),
            "importTaxNote": str(rate.get("importTaxNote") or "").strip(),
        }

    @staticmethod
    def _pick_shipping_rate(product_doc: dict, row: dict) -> dict:
        """Choose the product's shipping rate for the row's destination.

        The lookup runs in this order:
        1. A rate whose region equals the target region (case-insensitive).
        2. An alias match: US spellings for US targets, "international"
           otherwise.
        3. The first rate listed.
        A product without rates yields a zero-cost rate for the detected region.
        """
        target_region = ImportModel._detect_shipping_region(row)
        shipping_rates = product_doc.get("shippingRates") or []

        if not shipping_rates:
            return {
                "region": target_region,
                "transitTime": "",
                "cost": 0.0,
                "importTax": 0.0,
                "importTaxNote": "",
            }

        def region_of(rate):
            return ImportModel._normalize_variant_text(rate.get("region"))

        normalized_target = ImportModel._normalize_variant_text(target_region)
        for rate in shipping_rates:
            if region_of(rate) == normalized_target:
                return ImportModel._normalize_shipping_rate(rate)

        aliases = ImportModel._US_RATE_REGIONS if target_region == "US" else {"international"}
        for rate in shipping_rates:
            if region_of(rate) in aliases:
                return ImportModel._normalize_shipping_rate(rate)

        return ImportModel._normalize_shipping_rate(shipping_rates[0])

    @staticmethod
    def _find_matching_variant(variants: list, requested_style: str, requested_size: str, requested_color: str = ""):
        """Return the variant matching style and size, preferring the requested color.

        Comparison ignores case and extra whitespace. When no variant has the
        requested color, or no color was requested, the first style+size
        match is returned. Returns None when nothing matches style+size.
        """
        normalize = ImportModel._normalize_variant_text
        wanted_style = normalize(requested_style)
        wanted_size = normalize(requested_size)
        wanted_color = normalize(requested_color)

        style_size_matches = [
            variant
            for variant in variants
            if normalize(variant.get("style")) == wanted_style
            and normalize(variant.get("size")) == wanted_size
        ]

        if wanted_color:
            for variant in style_size_matches:
                if normalize(variant.get("color")) == wanted_color:
                    return variant

        return style_size_matches[0] if style_size_matches else None

    @staticmethod
    def _validate_and_enrich_row(row: dict, product_catalog: dict):
        """Check a row against its product document and attach pricing data.

        Returns:
            ``(enriched_row, None)`` on success, or ``(None, message)`` with a
            user-facing (Vietnamese) message when the product, its variants
            or the requested style/size cannot be found.

        The enriched row carries the variant price, the line's import tax
        (quantity × per-item tax) and private ``_shipping_*`` fields that are
        consumed by ``_build_order_doc``.
        """
        product_id = str(row.get("productId") or "").strip()
        product_doc = products_collection.find_one({"productId": product_id})
        product_name = ImportModel._get_product_display_name(product_doc, row, product_catalog)
        unavailable_message = f'Đơn hàng "{product_name}" hiện tại không còn mặt hàng này.'

        if not product_doc:
            return None, unavailable_message

        requested_style = str(row.get("style") or "").strip()
        requested_size = str(row.get("size") or "").strip()
        requested_color = str(row.get("color") or "").strip()

        if not requested_style:
            return None, f'Đơn hàng "{product_name}" đang thiếu style.'

        if not requested_size:
            return None, f'Đơn hàng "{product_name}" đang thiếu size.'

        variants = product_doc.get("variants") or []
        if not variants:
            return None, unavailable_message

        normalize = ImportModel._normalize_variant_text
        available_styles = {
            normalize(v.get("style")) for v in variants if str(v.get("style") or "").strip()
        }
        available_sizes = {
            normalize(v.get("size")) for v in variants if str(v.get("size") or "").strip()
        }

        if normalize(requested_style) not in available_styles:
            return None, f'Đơn hàng "{product_name}" hiện tại đã hết style "{requested_style}".'

        if normalize(requested_size) not in available_sizes:
            return None, f'Đơn hàng "{product_name}" hiện tại đã hết size "{requested_size}".'

        matched_variant = ImportModel._find_matching_variant(
            variants=variants,
            requested_style=requested_style,
            requested_size=requested_size,
            requested_color=requested_color,
        )
        if not matched_variant:
            # Style and size each exist, but not together in one variant.
            return None, unavailable_message

        shipping_rate = ImportModel._pick_shipping_rate(product_doc, row)
        quantity = ImportModel._to_int(row.get("quantity"), 1)
        # Variant price wins; fall back to the product-level price, then 0.
        variant_price = ImportModel._to_float(
            matched_variant.get("price"),
            product_doc.get("price") or 0,
        )
        import_tax_per_item = ImportModel._to_float(shipping_rate.get("importTax"), 0)

        enriched_row = {
            **row,
            "productId": product_id,
            "productName": product_name,
            "style": requested_style,
            "size": requested_size,
            "color": requested_color or str(matched_variant.get("color") or "").strip(),
            "quantity": quantity,
            "price": variant_price,
            "tax": quantity * import_tax_per_item,
            "_shipping_cost_once": ImportModel._to_float(shipping_rate.get("cost"), 0),
            "_shipping_region": shipping_rate.get("region", ""),
            "_shipping_tax_note": shipping_rate.get("importTaxNote", ""),
        }
        return enriched_row, None

    @staticmethod
    def _parse_item(row: dict, product_catalog: dict, shipping_override: float = 0.0) -> dict:
        """Build one order item from an (enriched) row.

        Quantity comes from the row's "quantity", else from a "(Qty: N)" marker
        in the "items" cell, else 1. The "shipping" field is set to
        ``shipping_override``.
        """
        # str(): spreadsheet cells may arrive as numbers, not text.
        item_text = str(row.get("items") or "").strip()
        qty_match = ImportModel._QTY_RE.search(item_text)
        parsed_qty = int(qty_match.group(1)) if qty_match else 1

        product_id = str(row.get("productId") or "").strip()
        product_name = str(
            row.get("productName") or product_catalog.get(product_id) or ""
        ).strip()

        return {
            "style": str(row.get("style") or "").strip(),
            "productId": product_id,
            "productName": product_name,
            "quantity": ImportModel._to_int(row.get("quantity"), parsed_qty),
            "price": ImportModel._to_float(row.get("price"), 0),
            "color": str(row.get("color") or "").strip(),
            "size": str(row.get("size") or "").strip(),
            "design": row.get("design") or "",
            "shipping": ImportModel._to_float(shipping_override, 0),
            "tax": ImportModel._to_float(row.get("tax"), 0),
        }

    @staticmethod
    def _get_order_shipping_once(grouped_rows: list) -> float:
        """Return the highest per-row shipping cost in the group (0.0 if empty)."""
        return max(
            (ImportModel._to_float(row.get("_shipping_cost_once"), 0) for row in grouped_rows),
            default=0.0,
        )

    @staticmethod
    def _build_order_doc(
        grouped_rows: list,
        file_name: str,
        product_catalog: dict,
        id_khach_hang: str = None,
        merge_key: str = "",
        created_at: str = None,
        final_order_id: str = "",
    ) -> dict:
        """Assemble the order document for one group of enriched rows.

        Customer details come from the first row. Shipping is charged once
        per order and attached to the first item. The total is
        ``sum(qty * price) + sum(tax) + shipping``, rounded to 2 decimals.
        ``grouped_rows`` must be non-empty.
        """
        first_row = grouped_rows[0]
        order_shipping_once = ImportModel._get_order_shipping_once(grouped_rows)

        items = [
            ImportModel._parse_item(
                row=row,
                product_catalog=product_catalog,
                shipping_override=order_shipping_once if index == 0 else 0.0,
            )
            for index, row in enumerate(grouped_rows)
        ]

        subtotal = sum(item["quantity"] * item["price"] for item in items)
        total_tax = sum(item["tax"] for item in items)
        total = subtotal + total_tax + order_shipping_once

        uploader_info = ImportModel._build_uploader_info(id_khach_hang)

        return {
            "orderId": final_order_id,
            "sourceOrderIds": [],
            "mergeKey": merge_key,
            "customer": {
                "name": first_row.get("customer", ""),
                "phone": first_row.get("phone", ""),
                "email": first_row.get("email", ""),
                "address": ImportModel._parse_address(first_row.get("address", "")),
            },
            "items": items,
            "status": "Unpaid",
            "total": round(total, 2),
            "id_khach_hang": id_khach_hang,
            "uploadedByName": uploader_info["uploadedByName"],
            "uploadedByEmail": uploader_info["uploadedByEmail"],
            # NOTE(review): naive UTC ISO string (no offset); _pick_keeper_doc
            # compares these as strings, so keep the format stable.
            "createdAt": created_at or datetime.utcnow().isoformat(),
            "importSource": file_name,
            "tracking": "",
            "trackingNumber": "",
            "trackingLink": "",
        }

    @staticmethod
    def _pick_keeper_doc(candidates: list, merge_key: str):
        """Pick the existing order to keep among duplicates (None if no candidates).

        Docs with an exactly matching mergeKey are preferred. Among those,
        the earliest createdAt wins, with the _id string as a tie-breaker.
        """
        if not candidates:
            return None

        pool = [doc for doc in candidates if doc.get("mergeKey") == merge_key] or candidates
        return min(pool, key=lambda doc: (str(doc.get("createdAt") or ""), str(doc.get("_id"))))

    @staticmethod
    def _build_order_id_prefix(vn_now) -> str:
        """Return "ORD-<YYYYmmddHHMMSS><mmm>-" where <mmm> is milliseconds (000-999)."""
        timestamp_part = vn_now.strftime("%Y%m%d%H%M%S")
        # microsecond // 1000 -> milliseconds; "// 10000" would only give 0-99.
        milli_part = f"{vn_now.microsecond // 1000:03d}"
        return f"ORD-{timestamp_part}{milli_part}-"

    @staticmethod
    def _generate_order_id(vn_now=None) -> str:
        """Generate the next free "ORD-<timestamp><ms>-NNNN" id for *vn_now* (default: now, VN time).

        NNNN is one more than the highest 4-digit sequence already stored
        under the same prefix. It is bumped further until no order uses it.
        NOTE(review): check-then-insert is not atomic; concurrent imports can
        race — a unique index on orderId would make collisions fail loudly.
        """
        prefix = ImportModel._build_order_id_prefix(vn_now or datetime.now(VN_TZ))
        escaped_prefix = re.escape(prefix)
        sequence_re = re.compile(rf"^{escaped_prefix}(\d{{4}})$")

        existing_docs = orders_collection.find(
            {"orderId": {"$regex": rf"^{escaped_prefix}\d{{4}}$"}},
            {"orderId": 1},
        )

        max_seq = 0
        for doc in existing_docs:
            match = sequence_re.match(str(doc.get("orderId") or ""))
            if match:
                max_seq = max(max_seq, int(match.group(1)))

        next_seq = max_seq + 1
        while True:
            candidate = f"{prefix}{next_seq:04d}"
            if not orders_collection.find_one({"orderId": candidate}, {"_id": 1}):
                return candidate
            next_seq += 1

    @staticmethod
    def _collect_valid_rows(rows: list, product_catalog: dict):
        """Validate and enrich every row, grouping the valid ones by merge key.

        Returns:
            ``(grouped_orders, import_errors, valid_row_count)``, where
            ``grouped_orders`` maps merge key -> list of enriched rows and
            ``import_errors`` holds ``{"row": 1-based index, "message": str}``.
        """
        grouped_orders = defaultdict(list)
        import_errors = []
        valid_row_count = 0

        for row_number, row in enumerate(rows, start=1):
            basic_errors = ImportModel._validate_basic_row(row)
            if basic_errors:
                import_errors.append({
                    "row": row_number,
                    "message": ", ".join(basic_errors)
                })
                continue

            enriched_row, variant_error = ImportModel._validate_and_enrich_row(
                row=row,
                product_catalog=product_catalog,
            )
            if variant_error:
                import_errors.append({
                    "row": row_number,
                    "message": variant_error
                })
                continue

            grouped_orders[ImportModel._build_group_key(enriched_row)].append(enriched_row)
            valid_row_count += 1

        return grouped_orders, import_errors, valid_row_count

    @staticmethod
    def _save_order_group(merge_key: str, grouped_rows: list, file_name: str,
                          product_catalog: dict, id_khach_hang: str = None) -> dict:
        """Insert or update the order for one group and return the written document.

        If this uploader already has orders with the same merge key, the
        earliest one is kept and overwritten, and the remaining duplicates are
        deleted. Otherwise a new order with a freshly generated id is inserted.
        """
        existed_docs = list(orders_collection.find({
            "id_khach_hang": id_khach_hang,
            "mergeKey": merge_key,
        }))
        keeper_doc = ImportModel._pick_keeper_doc(existed_docs, merge_key)

        keeper_order_id = str(keeper_doc.get("orderId") or "").strip() if keeper_doc else ""
        # A kept order with no usable orderId gets a fresh one rather than "".
        final_order_id = keeper_order_id or ImportModel._generate_order_id()

        order_doc = ImportModel._build_order_doc(
            grouped_rows=grouped_rows,
            file_name=file_name,
            product_catalog=product_catalog,
            id_khach_hang=id_khach_hang,
            merge_key=merge_key,
            created_at=keeper_doc.get("createdAt") if keeper_doc else None,
            final_order_id=final_order_id,
        )

        if keeper_doc:
            # NOTE(review): this $set replaces items/total and resets status to
            # "Unpaid" and the tracking fields to "", and the lookup above does
            # not filter by status — re-importing a customer whose earlier
            # order was paid/shipped will overwrite it. Confirm this is intended.
            orders_collection.update_one(
                {"_id": keeper_doc["_id"]},
                {"$set": order_doc},
            )

            duplicate_ids = [
                doc["_id"]
                for doc in existed_docs
                if str(doc["_id"]) != str(keeper_doc["_id"])
            ]
            if duplicate_ids:
                orders_collection.delete_many({"_id": {"$in": duplicate_ids}})
        else:
            orders_collection.insert_one(order_doc)

        return order_doc

    @staticmethod
    def _build_preview(order_docs: list, limit: int = 10) -> list:
        """Summarise the first *limit* saved orders for the import response."""
        return [
            {
                "orderId": doc["orderId"],
                "sourceOrderIds": doc.get("sourceOrderIds", []),
                "customer": doc["customer"]["name"],
                "items": f"{len(doc['items'])} item(s)",
                "address": doc["customer"]["address"]["street"],
            }
            for doc in order_docs[:limit]
        ]

    @staticmethod
    def process_import(file_name: str, file_size: int, rows: list, id_khach_hang: str = None):
        """Validate, group and persist uploaded order rows.

        Args:
            file_name: Stored on each order as ``importSource``.
            file_size: Accepted for API compatibility; currently unused.
            rows: Row dicts (customer, address, productId, quantity, style,
                size, color, ...).
            id_khach_hang: Uploader id. It scopes the merge lookup and is
                stored on each order.

        Returns:
            A summary dict with row/order counts, per-row errors (1-based row
            numbers) and a preview of up to 10 saved orders.
        """
        product_catalog = ProductCatalogModel.get_product_catalog_map()
        grouped_orders, import_errors, valid_row_count = ImportModel._collect_valid_rows(
            rows, product_catalog
        )

        saved_orders = []
        for merge_key, grouped_rows in grouped_orders.items():
            saved_orders.append(
                ImportModel._save_order_group(
                    merge_key=merge_key,
                    grouped_rows=grouped_rows,
                    file_name=file_name,
                    product_catalog=product_catalog,
                    id_khach_hang=id_khach_hang,
                )
            )

        return {
            # Same timezone as generated order IDs, independent of server locale.
            "importId": f"IMP-{datetime.now(VN_TZ).strftime('%Y%m%d%H%M%S')}",
            "totalRows": len(rows),
            "validRows": valid_row_count,
            "invalidRows": len(import_errors),
            "groupedOrderCount": len(grouped_orders),
            "insertedOrderCount": len(saved_orders),
            "errors": import_errors,
            "previewData": ImportModel._build_preview(saved_orders),
        }