144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Генерация embeddings и загрузка в ChromaDB.
|
||
Модель: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
|
||
"""
|
||
|
||
import json
|
||
import sys
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
|
||
CHROMA_PATH = str(Path(__file__).parent.parent / "data" / "chromadb")
|
||
DATA_FILE = Path(__file__).parent.parent / "data" / "messages.jsonl"
|
||
COLLECTION_NAME = "snowbike_embeddings"
|
||
MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||
BATCH_SIZE = 64
|
||
MAX_TEXT_LEN = 512 # Обрезаем очень длинные тексты
|
||
|
||
|
||
def load_messages():
|
||
"""Загружаем сообщения из JSONL."""
|
||
messages = []
|
||
if not DATA_FILE.exists():
|
||
print(f"ОШИБКА: файл {DATA_FILE} не найден. Сначала запустите parse_messages.py")
|
||
sys.exit(1)
|
||
|
||
with open(DATA_FILE, "r", encoding="utf-8") as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if line:
|
||
msg = json.loads(line)
|
||
# Пропускаем сообщения с очень коротким текстом (< 5 символов)
|
||
if len(msg.get("text", "")) >= 5:
|
||
messages.append(msg)
|
||
|
||
print(f"Загружено {len(messages)} сообщений (с текстом >= 5 символов)")
|
||
return messages
|
||
|
||
|
||
def get_or_create_collection(client):
|
||
"""Получаем или создаём коллекцию ChromaDB."""
|
||
try:
|
||
collection = client.get_collection(COLLECTION_NAME)
|
||
count = collection.count()
|
||
print(f"Коллекция '{COLLECTION_NAME}' уже существует, {count} документов")
|
||
return collection, count
|
||
except Exception:
|
||
collection = client.create_collection(
|
||
name=COLLECTION_NAME,
|
||
metadata={"description": "Snowbike Russia Telegram messages embeddings"}
|
||
)
|
||
print(f"Коллекция '{COLLECTION_NAME}' создана")
|
||
return collection, 0
|
||
|
||
|
||
def main():
|
||
print("=== ChromaDB Индексация ===")
|
||
|
||
# Импортируем здесь чтобы показать ошибки явно
|
||
try:
|
||
import chromadb
|
||
from sentence_transformers import SentenceTransformer
|
||
except ImportError as e:
|
||
print(f"ОШИБКА импорта: {e}")
|
||
print("Установите: pip install chromadb sentence-transformers")
|
||
sys.exit(1)
|
||
|
||
# Загружаем данные
|
||
messages = load_messages()
|
||
|
||
# Загружаем модель
|
||
print(f"\nЗагружаем модель {MODEL_NAME}...")
|
||
model = SentenceTransformer(MODEL_NAME)
|
||
print("Модель загружена")
|
||
|
||
# Подключаемся к ChromaDB
|
||
print(f"\nПодключаемся к ChromaDB: {CHROMA_PATH}")
|
||
os.makedirs(CHROMA_PATH, exist_ok=True)
|
||
client = chromadb.PersistentClient(path=CHROMA_PATH)
|
||
|
||
collection, existing_count = get_or_create_collection(client)
|
||
|
||
# Получаем уже проиндексированные IDs (если есть)
|
||
if existing_count > 0:
|
||
print(f"Уже проиндексировано {existing_count} документов")
|
||
existing_ids = set(collection.get(include=[])["ids"])
|
||
messages = [m for m in messages if str(m["id"]) not in existing_ids]
|
||
print(f"Осталось добавить: {len(messages)} документов")
|
||
|
||
if not messages:
|
||
print("Нечего индексировать, всё уже есть!")
|
||
return
|
||
|
||
total = len(messages)
|
||
indexed = 0
|
||
start_time = time.time()
|
||
|
||
print(f"\nИндексируем {total} сообщений батчами по {BATCH_SIZE}...")
|
||
|
||
for i in range(0, total, BATCH_SIZE):
|
||
batch = messages[i:i + BATCH_SIZE]
|
||
|
||
texts = [m["text"][:MAX_TEXT_LEN] for m in batch]
|
||
ids = [str(m["id"]) for m in batch]
|
||
metadatas = [
|
||
{
|
||
"topic_id": m["topic_id"],
|
||
"topic_title": m["topic_title"],
|
||
"date": m.get("date", ""),
|
||
"from_id": str(m.get("from_id", "")),
|
||
"month": int(m["date"][5:7]) if m.get("date") and len(m["date"]) >= 7 else 0,
|
||
}
|
||
for m in batch
|
||
]
|
||
|
||
# Генерируем embeddings
|
||
embeddings = model.encode(texts, show_progress_bar=False).tolist()
|
||
|
||
# Добавляем в ChromaDB
|
||
collection.add(
|
||
ids=ids,
|
||
embeddings=embeddings,
|
||
documents=texts,
|
||
metadatas=metadatas,
|
||
)
|
||
|
||
indexed += len(batch)
|
||
elapsed = time.time() - start_time
|
||
speed = indexed / elapsed if elapsed > 0 else 0
|
||
eta = (total - indexed) / speed if speed > 0 else 0
|
||
progress = (indexed / total) * 100
|
||
|
||
print(f" {indexed}/{total} ({progress:.1f}%) | {speed:.0f} msg/s | ETA: {eta:.0f}s", end="\r")
|
||
|
||
elapsed = time.time() - start_time
|
||
print(f"\n\nИндексация завершена за {elapsed:.0f}с")
|
||
print(f"Всего в коллекции: {collection.count()} документов")
|
||
print("\n✓ Готово!")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|