"""Search API routes — keyword + semantic search.""" from __future__ import annotations from fastapi import APIRouter, Query from core import restaurant, vector, cache from core.db import conn router = APIRouter() @router.get("") def search_restaurants( q: str = Query(..., min_length=1), mode: str = Query("keyword", pattern="^(keyword|semantic|hybrid)$"), limit: int = Query(20, le=100), ): """Search restaurants by keyword, semantic similarity, or hybrid.""" key = cache.make_key("search", f"q={q}", f"m={mode}", f"l={limit}") cached = cache.get(key) if cached is not None: return cached if mode == "semantic": result = _semantic_search(q, limit) cache.set(key, result) return result elif mode == "hybrid": kw = _keyword_search(q, limit) sem = _semantic_search(q, limit) # merge: keyword results first, then semantic results not already in keyword seen = {r["id"] for r in kw} merged = list(kw) for r in sem: if r["id"] not in seen: merged.append(r) seen.add(r["id"]) result = merged[:limit] cache.set(key, result) return result else: result = _keyword_search(q, limit) cache.set(key, result) return result def _keyword_search(q: str, limit: int) -> list[dict]: # JOIN video_restaurants to also search foods_mentioned and video title sql = """ SELECT DISTINCT r.id, r.name, r.address, r.region, r.latitude, r.longitude, r.cuisine_type, r.price_range, r.google_place_id, r.business_status, r.rating, r.rating_count FROM restaurants r JOIN video_restaurants vr ON vr.restaurant_id = r.id JOIN videos v ON v.id = vr.video_id WHERE r.latitude IS NOT NULL AND (UPPER(r.name) LIKE UPPER(:q) OR UPPER(r.address) LIKE UPPER(:q) OR UPPER(r.region) LIKE UPPER(:q) OR UPPER(r.cuisine_type) LIKE UPPER(:q) OR UPPER(vr.foods_mentioned) LIKE UPPER(:q) OR UPPER(v.title) LIKE UPPER(:q)) FETCH FIRST :lim ROWS ONLY """ pattern = f"%{q}%" with conn() as c: cur = c.cursor() cur.execute(sql, {"q": pattern, "lim": limit}) cols = [d[0].lower() for d in cur.description] rows = [dict(zip(cols, row)) for row in cur.fetchall()] # Attach channel names if rows: _attach_channels(rows) return rows def _semantic_search(q: str, limit: int) -> list[dict]: similar = vector.search_similar(q, top_k=max(30, limit * 3)) if not similar: return [] # Deduplicate by restaurant_id, preserving distance order (best first) seen: set[str] = set() ordered_ids: list[str] = [] for s in similar: rid = s["restaurant_id"] if rid not in seen: seen.add(rid) ordered_ids.append(rid) results = [] for rid in ordered_ids[:limit]: r = restaurant.get_by_id(rid) if r and r.get("latitude"): results.append(r) if results: _attach_channels(results) return results def _attach_channels(rows: list[dict]): """Attach channel names to each restaurant dict.""" ids = [r["id"] for r in rows] placeholders = ", ".join(f":id{i}" for i in range(len(ids))) sql = f""" SELECT DISTINCT vr.restaurant_id, c.channel_name FROM video_restaurants vr JOIN videos v ON v.id = vr.video_id JOIN channels c ON c.id = v.channel_id WHERE vr.restaurant_id IN ({placeholders}) """ params = {f"id{i}": rid for i, rid in enumerate(ids)} ch_map: dict[str, list[str]] = {} with conn() as c: cur = c.cursor() cur.execute(sql, params) for row in cur.fetchall(): ch_map.setdefault(row[0], []).append(row[1]) for r in rows: r["channels"] = ch_map.get(r["id"], [])