"""Hybrid search implementation using Qdrant."""

import copy
import hashlib
import json
import logging
import re
from collections import defaultdict
from typing import Any

import httpx
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Filter,
    FieldCondition,
    MatchValue,
    SparseVector,
)

from .config import SearchConfig
from .models import SearchResult, SearchQuery
from .embeddings import create_embedding_client, EmbeddingClient
from .synonyms import SynonymExpander
from .weights import get_word_weight, STOPWORDS


logger = logging.getLogger(__name__)


class QdrantHybridSearch:
    """Hybrid search combining dense, sparse, and matryoshka vectors.

    A query is enriched with synonyms and embedded once. The embedding is
    searched against the matryoshka vectors and, separately, against the
    dense + sparse vectors (those two fused with Reciprocal Rank Fusion).
    Both result lists are then merged with RRF and optionally reranked by an
    LLM before the ``min_score`` / ``top_k`` cut is applied.
    """

    # Minimum number of candidates each retrieval path returns before the
    # final merge; raised to ``top_k`` when a caller asks for more results.
    _MIN_CANDIDATES = 25

    # Maximum number of merged candidates sent to the LLM reranker.
    _MAX_RERANK_CANDIDATES = 30

    # Candidate pool size for the first (smallest) matryoshka stage.
    _MATRYOSHKA_FIRST_STAGE_LIMIT = 100

    def __init__(
        self,
        config: SearchConfig | None = None,
        embedding_client: EmbeddingClient | None = None,
        synonym_expander: SynonymExpander | None = None,
    ):
        """Initialize the hybrid search.

        Args:
            config: Search configuration. Uses defaults if None.
            embedding_client: Custom embedding client. Created from config if None.
            synonym_expander: Custom synonym expander. Uses default if None.
        """
        self.config = config or SearchConfig()
        self.client = QdrantClient(url=self.config.qdrant_url)
        self.embedding_client = embedding_client or create_embedding_client(self.config)
        self.synonym_expander = synonym_expander or SynonymExpander()

    def search(
        self,
        query: str | SearchQuery,
        top_k: int | None = None,
        min_score: float | None = None,
        filters: dict[str, Any] | None = None,
        use_reranking: bool | None = None,
    ) -> list[SearchResult]:
        """Perform hybrid search.

        Args:
            query: Search query string or SearchQuery object. A SearchQuery
                passed in is not modified; overrides apply to a copy.
            top_k: Number of results to return. Uses config default if None.
            min_score: Minimum score threshold. Uses config default if None
                (an explicit 0.0 is honored).
            filters: Optional ``{field: value}`` filters for the search.
            use_reranking: Whether to use LLM reranking. Uses config default if None.

        Returns:
            At most ``top_k`` results with ``score >= min_score``, sorted by
            relevance.
        """
        query = self._normalize_query(query, top_k, min_score, filters, use_reranking)

        search_filter = self._build_filter(query.filters)

        # Enrich query with synonyms for better semantic matching, then embed it.
        enriched_query = self.synonym_expander.enrich_text(query.query)
        query_embedding = self.embedding_client.get_embedding(enriched_query)

        # Fetch at least top_k candidates per path so large top_k values are
        # not capped by the default candidate pool.
        candidate_limit = max(self._MIN_CANDIDATES, query.top_k)

        matryoshka_results = self._matryoshka_search(
            query_embedding, search_filter, limit=candidate_limit
        )
        hybrid_results = self._hybrid_search(
            query.query, query_embedding, search_filter, limit=candidate_limit
        )

        all_results = self._merge_results(matryoshka_results, hybrid_results)

        # Reranking only matters when there are more candidates than we return.
        if query.use_reranking and len(all_results) > query.top_k:
            all_results = self._rerank(query.query, all_results, top_k=query.top_k * 2)

        filtered = [r for r in all_results if r.score >= query.min_score]
        return filtered[:query.top_k]

    def _normalize_query(
        self,
        query: str | SearchQuery,
        top_k: int | None,
        min_score: float | None,
        filters: dict[str, Any] | None,
        use_reranking: bool | None,
    ) -> SearchQuery:
        """Build the effective SearchQuery from ``query`` plus per-call overrides."""
        if isinstance(query, str):
            # ``is not None`` rather than ``or`` so explicit falsy overrides
            # (e.g. min_score=0.0) are not replaced by config defaults.
            return SearchQuery(
                query=query,
                top_k=top_k if top_k is not None else self.config.top_k,
                min_score=min_score if min_score is not None else self.config.min_score,
                filters=filters or {},
                use_reranking=use_reranking if use_reranking is not None else self.config.use_reranking,
            )

        # Shallow copy so this call's overrides don't leak into the caller's object.
        query = copy.copy(query)
        if top_k is not None:
            query.top_k = top_k
        if min_score is not None:
            query.min_score = min_score
        if filters is not None:
            query.filters = filters
        if use_reranking is not None:
            query.use_reranking = use_reranking
        return query

    def _build_filter(self, filters: dict[str, Any]) -> Filter | None:
        """Translate ``{field: value}`` into a Qdrant filter requiring every match.

        Returns None for empty/None ``filters`` so the search runs unfiltered.
        """
        if not filters:
            return None
        return Filter(
            must=[
                FieldCondition(key=key, match=MatchValue(value=value))
                for key, value in filters.items()
            ]
        )

    def _matryoshka_search(
        self,
        embedding: list[float],
        search_filter: Filter | None,
        limit: int = 25,
    ) -> list[SearchResult]:
        """Search the matryoshka vectors from the smallest to the largest dimension.

        Each configured dimension ``d`` is queried through the named vector
        ``matryoshka_{d}`` with the first ``d`` components of ``embedding``.
        The per-stage limit starts at 100 (or ``limit`` if larger) and halves
        each stage, never dropping below ``limit``. Results come from the
        largest-dimension stage.

        Args:
            embedding: Full query embedding; truncated per stage.
            search_filter: Optional Qdrant filter applied at every stage.
            limit: Maximum number of results to return.

        Returns:
            At most ``limit`` results, or an empty list when no dimensions are
            configured or any stage finds nothing.
        """
        dims = sorted(self.config.matryoshka_dims)
        if not dims:
            # Nothing to search; also avoids reading ``points`` unassigned.
            return []

        # NOTE(review): every stage searches the whole (filtered) collection;
        # earlier stages do not restrict later ones to their candidates, so
        # they currently only act as an early exit when nothing matches.
        current_limit = max(self._MATRYOSHKA_FIRST_STAGE_LIMIT, limit)
        points = []
        for dim in dims:
            points = self.client.query_points(
                collection_name=self.config.collection_name,
                query=embedding[:dim],
                using=f"matryoshka_{dim}",
                query_filter=search_filter,
                limit=current_limit,
                with_payload=True,
            ).points

            if not points:
                return []

            # Reduce limit for higher dimensions.
            current_limit = max(limit, current_limit // 2)

        # With fewer than three stages the final query may return more than
        # ``limit`` points; honor the requested limit.
        return [SearchResult.from_qdrant_point(p, p.score) for p in points[:limit]]

    def _hybrid_search(
        self,
        query: str,
        embedding: list[float],
        search_filter: Filter | None,
        limit: int = 25,
    ) -> list[SearchResult]:
        """Combine dense and sparse search with RRF fusion.

        Args:
            query: Raw query text; converted to a sparse vector.
            embedding: Dense query embedding for the ``dense`` named vector.
            search_filter: Optional Qdrant filter applied to both searches.
            limit: Per-search limit and maximum number of fused results.

        Returns:
            Fused results, best first, with scores normalized so the best is
            1.0. Sparse search is best-effort: on failure only dense results
            contribute.
        """
        dense_results = self.client.query_points(
            collection_name=self.config.collection_name,
            query=embedding,
            using="dense",
            query_filter=search_filter,
            limit=limit,
            with_payload=True,
        ).points

        sparse_results = []
        sparse_vector = self._text_to_sparse(query)
        # A query consisting only of stopwords / zero-weight words yields no
        # sparse terms; there is nothing to match, so skip the round trip.
        if sparse_vector.indices:
            try:
                sparse_results = self.client.query_points(
                    collection_name=self.config.collection_name,
                    query=sparse_vector,
                    using="sparse",
                    query_filter=search_filter,
                    limit=limit,
                    with_payload=True,
                ).points
            except Exception as e:
                # Best-effort: degrade to dense-only instead of failing the search.
                logger.warning("Sparse search error: %s", e)

        return self._rrf_fusion(
            [
                [(r.payload, r.score, r) for r in dense_results],
                [(r.payload, r.score, r) for r in sparse_results],
            ],
            k=self.config.rrf_k,
            limit=limit,
        )

    def _text_to_sparse(self, text: str) -> SparseVector:
        """Convert text to a sparse vector with synonym enrichment and IDF-like weighting.

        Tokens are lowercased, stripped to alphanumerics, and dropped if they
        are single characters or in ``STOPWORDS``. Each remaining token maps
        to index ``int(md5(token)[:8], 16) % 100000`` and contributes
        ``count * get_word_weight(token)`` there; tokens whose weight is not
        positive are skipped, and hash collisions sum into the same index.

        NOTE(review): enrichment is meant to mirror indexing; the tokenization
        and hashing presumably do too — verify before changing either.
        """
        enriched_text = self.synonym_expander.enrich_text(text)

        word_counts: dict[str, int] = {}
        for raw_word in enriched_text.lower().split():
            token = "".join(c for c in raw_word if c.isalnum())
            if len(token) > 1 and token not in STOPWORDS:
                word_counts[token] = word_counts.get(token, 0) + 1

        # Aggregate by hashed index with IDF-like weighting.
        index_values: dict[int, float] = {}
        for word, count in word_counts.items():
            weight = get_word_weight(word)
            if weight > 0:
                idx = int(hashlib.md5(word.encode()).hexdigest()[:8], 16) % 100000
                index_values[idx] = index_values.get(idx, 0) + float(count) * weight

        return SparseVector(
            indices=list(index_values.keys()),
            values=list(index_values.values()),
        )

    def _rrf_fusion(
        self,
        result_lists: list[list[tuple[dict, float, Any]]],
        k: int = 60,
        limit: int = 25,
    ) -> list[SearchResult]:
        """Reciprocal Rank Fusion of multiple Qdrant result lists.

        Each point scores ``sum(1 / (k + rank + 1))`` over the lists it
        appears in (rank is 0-based), keyed by Qdrant point ID. Only the point
        (third tuple element) is used; payload and raw score are ignored.

        Returns:
            Up to ``limit`` results, best first, with scores divided by the
            best score (so the top result is 1.0); empty if there is no input.
        """
        scores: dict[str, float] = defaultdict(float)
        points: dict[str, Any] = {}

        for results in result_lists:
            for rank, (_payload, _score, point) in enumerate(results):
                point_id = str(point.id)
                scores[point_id] += 1.0 / (k + rank + 1)
                points[point_id] = point

        sorted_ids = sorted(scores, key=scores.__getitem__, reverse=True)
        if not sorted_ids:
            return []

        max_score = scores[sorted_ids[0]]
        return [
            SearchResult.from_qdrant_point(points[pid], scores[pid] / max_score)
            for pid in sorted_ids[:limit]
        ]

    def _merge_results(
        self,
        results1: list[SearchResult],
        results2: list[SearchResult],
    ) -> list[SearchResult]:
        """Merge two result lists using RRF fusion keyed by ``SearchResult.id``.

        The first occurrence of each ID is kept and its ``score`` is
        overwritten in place with the normalized fused score (top result 1.0).
        All unique results are returned, best first.
        """
        scores: dict[str, float] = defaultdict(float)
        results_by_id: dict[str, SearchResult] = {}
        k = self.config.rrf_k

        for results in (results1, results2):
            for rank, r in enumerate(results):
                scores[r.id] += 1.0 / (k + rank + 1)
                results_by_id.setdefault(r.id, r)

        sorted_ids = sorted(scores, key=scores.__getitem__, reverse=True)
        if not sorted_ids:
            return []

        max_score = scores[sorted_ids[0]]
        merged = []
        for pid in sorted_ids:
            result = results_by_id[pid]
            result.score = scores[pid] / max_score
            merged.append(result)
        return merged

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int = 20,
    ) -> list[SearchResult]:
        """Rerank the leading candidates using an LLM.

        Without ``config.embedding_api_key`` reranking is skipped and
        ``results`` is returned unchanged (not truncated). Otherwise up to 30
        leading candidates are reordered: model-ranked ones first, then the
        rest in their original order. Scores are reassigned as
        ``1 - position / len(candidates)`` so they decrease with the final
        order, which keeps the caller's ``min_score`` cut consistent with
        ranking. If the model call fails or yields no usable ordering,
        ``results[:top_k]`` is returned with scores untouched.

        Returns:
            At most ``top_k`` results.
        """
        if not self.config.embedding_api_key:
            return results

        candidates = results[:self._MAX_RERANK_CANDIDATES]
        order = self._request_rerank_order(query, candidates)
        if not order:
            return results[:top_k]

        # Candidates the model omitted follow the ranked ones, in original order.
        ranked = set(order)
        order += [i for i in range(len(candidates)) if i not in ranked]

        reranked: list[SearchResult] = []
        for position, idx in enumerate(order):
            result = candidates[idx]
            result.score = 1.0 - position / len(candidates)
            reranked.append(result)
        return reranked[:top_k]

    def _request_rerank_order(
        self,
        query: str,
        candidates: list[SearchResult],
    ) -> list[int]:
        """Ask the rerank model for a relevance ordering of ``candidates``.

        Returns 0-based candidate positions, most relevant first, with
        duplicates and out-of-range entries dropped. Returns an empty list when
        the request fails or the reply has no JSON integer array (best-effort).
        """
        # NOTE(review): assumes SearchResult exposes a dict-like ``.get``; verify.
        docs_text = "\n".join(
            f"{i+1}. [{r.id}] {r.title or r.get('summary', '')} - {r.description[:100] if r.description else ''}"
            for i, r in enumerate(candidates)
        )

        prompt = f"""Rate the relevance of each item to the query.
Query: "{query}"

Items:
{docs_text}

Return a JSON array of item indices (1-based) sorted by relevance, most relevant first.
Only return the JSON array, nothing else. Example: [3, 1, 7, 2, ...]"""

        try:
            response = httpx.post(
                f"{self.config.embedding_base_url}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.config.embedding_api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self.config.rerank_model,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0,
                },
                timeout=30.0,
            )
            response.raise_for_status()
            content = response.json()["choices"][0]["message"]["content"]
            match = re.search(r"\[[\d,\s]+\]", content)
            if match is None:
                return []
            indices = json.loads(match.group())
        except Exception as e:
            # Reranking is an optional refinement; never fail the search over it.
            logger.warning("Reranking failed: %s", e)
            return []

        order: list[int] = []
        seen: set[int] = set()
        for idx in indices:
            if 1 <= idx <= len(candidates) and idx not in seen:
                seen.add(idx)
                order.append(idx - 1)
        return order


# Lazily created, module-level shared search instance (see get_search).
_search_instance: QdrantHybridSearch | None = None


def get_search(config: SearchConfig | None = None) -> QdrantHybridSearch:
    """Return the shared QdrantHybridSearch, creating it on first use.

    Args:
        config: Configuration used only when the shared instance does not
            exist yet. Once it exists, this argument is ignored and the
            existing instance is returned as-is.

    Returns:
        The module-level search instance.
    """
    global _search_instance
    if _search_instance is not None:
        return _search_instance
    _search_instance = QdrantHybridSearch(config)
    return _search_instance
