"""Indexer for adding documents to Qdrant."""

import hashlib
import json
from pathlib import Path
from typing import Any

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    SparseVectorParams,
    SparseIndexParams,
    PointStruct,
    PayloadSchemaType,
    SparseVector,
)

from .config import SearchConfig, IndexConfig
from .models import IndexDocument
from .embeddings import create_embedding_client, EmbeddingClient
from .synonyms import SynonymExpander
from .weights import get_word_weight, STOPWORDS


class QdrantIndexer:
    """Indexes documents into Qdrant with hybrid vectors (dense + sparse/BM25)."""

    def __init__(
        self,
        search_config: SearchConfig,
        index_config: IndexConfig | None = None,
        embedding_client: EmbeddingClient | None = None,
        synonym_expander: SynonymExpander | None = None,
    ):
        """Initialize the indexer.

        Args:
            search_config: Search/Qdrant configuration.
            index_config: Indexing configuration (fields, indexes).
            embedding_client: Custom embedding client. Created from config if None.
            synonym_expander: Custom synonym expander. Uses default if None.
        """
        self.search_config = search_config
        self.index_config = index_config or IndexConfig()
        self.client = QdrantClient(url=search_config.qdrant_url)
        self.embedding_client = embedding_client or create_embedding_client(search_config)
        self.synonym_expander = synonym_expander or SynonymExpander()

    def create_collection(self, reset: bool = False) -> None:
        """Create the Qdrant collection with required vector configuration.

        Args:
            reset: If True, delete existing collection first.
        """
        collection_name = self.search_config.collection_name

        if reset:
            try:
                self.client.delete_collection(collection_name)
                print(f"Deleted existing collection: {collection_name}")
            except Exception:
                pass

        collections = self.client.get_collections().collections
        if any(c.name == collection_name for c in collections):
            print(f"Collection {collection_name} already exists")
            return

        # Dense vector config
        vectors_config: dict[str, VectorParams] = {
            "dense": VectorParams(
                size=self.search_config.embedding_dims,
                distance=Distance.COSINE,
            ),
        }

        # Matryoshka vector configs for progressive search
        for dim in self.search_config.matryoshka_dims:
            vectors_config[f"matryoshka_{dim}"] = VectorParams(
                size=dim,
                distance=Distance.COSINE,
            )

        # Sparse vector config for BM25-like text search
        sparse_vectors_config = {
            "sparse": SparseVectorParams(
                index=SparseIndexParams(on_disk=False)
            )
        }

        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
            sparse_vectors_config=sparse_vectors_config,
        )

        print(f"Created collection: {collection_name}")
        print(f"  - Dense vector: {self.search_config.embedding_dims} dims")
        print(f"  - Matryoshka: {self.search_config.matryoshka_dims}")
        print(f"  - Sparse (BM25): enabled")

    def create_payload_indexes(self) -> None:
        """Create payload indexes for filtering based on index_config."""
        type_map = {
            "keyword": PayloadSchemaType.KEYWORD,
            "integer": PayloadSchemaType.INTEGER,
            "float": PayloadSchemaType.FLOAT,
            "bool": PayloadSchemaType.BOOL,
            "text": PayloadSchemaType.TEXT,
        }

        for field_name, field_type in self.index_config.payload_indexes.items():
            schema_type = type_map.get(field_type.lower())
            if not schema_type:
                raise ValueError(f"Unknown field type: {field_type}. Use: {list(type_map.keys())}")

            self.client.create_payload_index(
                collection_name=self.search_config.collection_name,
                field_name=field_name,
                field_schema=schema_type,
            )
            print(f"Created payload index: {field_name} ({field_type})")

    def _extract_text_fields(self, data: dict[str, Any]) -> str:
        """Extract and combine text fields for embedding and BM25.

        Args:
            data: Document data dictionary.

        Returns:
            Combined text from configured text_fields.
        """
        parts: list[str] = []

        for field in self.index_config.text_fields:
            value = data.get(field)
            if value is None:
                continue

            if isinstance(value, list):
                # Handle list fields (e.g., tags)
                parts.extend(str(v) for v in value if v)
            elif isinstance(value, str):
                parts.append(value)
            else:
                parts.append(str(value))

        text = " ".join(filter(None, parts))

        # Enrich with synonyms for better recall
        return self.synonym_expander.enrich_text(text)

    def _text_to_sparse_vector(self, text: str) -> tuple[list[int], list[float]]:
        """Convert text to sparse vector with IDF-like weighting.

        Uses MD5 hash of words to create sparse indices.
        Applies IDF-like weights to reduce impact of common words.
        """
        words = text.lower().split()
        word_counts: dict[str, int] = {}

        for word in words:
            # Keep only alphanumeric characters
            word = "".join(c for c in word if c.isalnum())
            if len(word) > 1 and word not in STOPWORDS:
                word_counts[word] = word_counts.get(word, 0) + 1

        # Aggregate by index with IDF-like weighting
        index_values: dict[int, float] = {}

        for word, count in word_counts.items():
            weight = get_word_weight(word)
            if weight > 0:
                # Hash word to index (100k buckets)
                idx = int(hashlib.md5(word.encode()).hexdigest()[:8], 16) % 100000
                index_values[idx] = index_values.get(idx, 0) + float(count) * weight

        indices = list(index_values.keys())
        values = list(index_values.values())

        return indices, values

    def index_documents(
        self,
        documents: list[IndexDocument],
        batch_size: int | None = None,
    ) -> int:
        """Index documents into Qdrant.

        Args:
            documents: List of documents to index.
            batch_size: Number of documents per batch. Uses config default if None.

        Returns:
            Number of documents indexed.
        """
        batch_size = batch_size or self.index_config.batch_size
        total_indexed = 0

        for batch_start in range(0, len(documents), batch_size):
            batch = documents[batch_start : batch_start + batch_size]
            print(f"Processing batch {batch_start // batch_size + 1}...")

            # Prepare search texts from configured fields
            search_texts = [self._extract_text_fields(doc.payload) for doc in batch]

            # Get embeddings
            print("  Getting embeddings...")
            embeddings = self.embedding_client.get_embeddings_batch(search_texts)

            points = []
            for i, (doc, search_text, embedding) in enumerate(
                zip(batch, search_texts, embeddings)
            ):
                # Build vector dict with dense and matryoshka vectors
                vectors: dict[str, Any] = {"dense": embedding}
                for dim in self.search_config.matryoshka_dims:
                    vectors[f"matryoshka_{dim}"] = embedding[:dim]

                # Build sparse vector for BM25
                sparse_indices, sparse_values = self._text_to_sparse_vector(search_text)

                # Build payload (original data + optional search text)
                payload = {**doc.payload}
                if self.index_config.store_text:
                    payload["_search_text"] = search_text

                # Create point
                point_id = batch_start + i
                if isinstance(doc.id, int):
                    point_id = doc.id
                elif isinstance(doc.id, str) and doc.id.isdigit():
                    point_id = int(doc.id)

                point = PointStruct(
                    id=point_id,
                    vector=vectors,
                    payload=payload,
                )

                point.vector["sparse"] = SparseVector(
                    indices=sparse_indices,
                    values=sparse_values,
                )

                points.append(point)

            self.client.upsert(
                collection_name=self.search_config.collection_name,
                points=points,
            )

            total_indexed += len(batch)
            print(f"  Indexed {total_indexed}/{len(documents)}")

        return total_indexed

    def index_from_jsonl(self, jsonl_path: Path | str) -> int:
        """Index documents from a JSONL file.

        Each line should be a JSON object. The id_field from index_config
        is used as document ID, all other fields go to payload.

        Args:
            jsonl_path: Path to JSONL file.

        Returns:
            Number of documents indexed.
        """
        path = Path(jsonl_path)
        if not path.exists():
            raise FileNotFoundError(f"JSONL file not found: {path}")

        documents = []
        with open(path, "r") as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    doc_id = data.get(self.index_config.id_field, len(documents))
                    documents.append(IndexDocument(
                        id=doc_id,
                        content="",  # Content is built from text_fields
                        payload=data,
                    ))

        print(f"Loaded {len(documents)} documents from {path}")
        print(f"Text fields for BM25: {self.index_config.text_fields}")
        print(f"Payload indexes: {self.index_config.payload_indexes}")

        return self.index_documents(documents)

    def index_from_json(self, json_path: Path | str, items_key: str | None = None) -> int:
        """Index documents from a JSON file.

        Args:
            json_path: Path to JSON file.
            items_key: Key containing array of items. If None, expects root array.

        Returns:
            Number of documents indexed.
        """
        path = Path(json_path)
        if not path.exists():
            raise FileNotFoundError(f"JSON file not found: {path}")

        with open(path, "r") as f:
            data = json.load(f)

        if items_key:
            items = data[items_key]
        else:
            items = data if isinstance(data, list) else [data]

        documents = []
        for item in items:
            doc_id = item.get(self.index_config.id_field, len(documents))
            documents.append(IndexDocument(
                id=doc_id,
                content="",
                payload=item,
            ))

        print(f"Loaded {len(documents)} documents from {path}")
        return self.index_documents(documents)

    def get_collection_info(self) -> dict[str, Any]:
        """Get information about the collection."""
        info = self.client.get_collection(self.search_config.collection_name)
        return {
            "name": self.search_config.collection_name,
            "points_count": info.points_count,
            "status": str(info.status),
        }
