183 lines
5.2 KiB
Python
183 lines
5.2 KiB
Python
|
|
import ollama
|
||
|
|
import numpy as np
|
||
|
|
from pathlib import Path
|
||
|
|
import pickle
|
||
|
|
from typing import List, Tuple
|
||
|
|
|
||
|
|
# Configuration

# Model name passed to ollama.embed() for both dataset chunks and queries.
EMBEDDING_MODEL = 'hf.co/CompendiumLabs/bge-base-en-v1.5-gguf'
# Model name passed to ollama.chat() when generating the answer.
LANGUAGE_MODEL = 'hf.co/bartowski/Llama-3.2-1B-Instruct-GGUF'
# Text file read by load_dataset(); every non-empty line becomes one chunk.
DATASET_PATH = 'cat-facts.txt'
# Pickle file used by build_database() to cache the embedded database.
CACHE_PATH = 'vector_db.pkl'
|
||
|
|
|
||
|
|
|
||
|
|
class VectorDatabase:
    """In-memory store of text chunks and their embedding vectors.

    ``chunks[i]`` and ``embeddings[i]`` always describe the same entry. The
    store can be written to, and restored from, a pickle file so embeddings
    do not have to be recomputed on every run.
    """

    def __init__(self):
        # Stripped chunk texts, in insertion order.
        self.chunks = []
        # One numpy array per chunk, in the same order as ``chunks``.
        self.embeddings = []

    def add(self, chunk: str, embedding: List[float]):
        """Append one entry to the database.

        Args:
            chunk: Text of the entry; surrounding whitespace is stripped.
            embedding: Vector for the chunk; stored as a numpy array.
        """
        self.chunks.append(chunk.strip())
        self.embeddings.append(np.array(embedding))

    def save(self, path: str):
        """Pickle the chunks and embeddings to *path* and report it on stdout.

        Args:
            path: Destination file; overwritten if it already exists.
        """
        with open(path, 'wb') as f:
            pickle.dump({'chunks': self.chunks, 'embeddings': self.embeddings}, f)
        print(f'Database cached to {path}')

    def load(self, path: str):
        """Replace the current contents with the entries pickled at *path*.

        The payload is validated before anything is assigned, so a rejected
        file leaves the database unchanged.

        Args:
            path: File previously written by :meth:`save`.

        Raises:
            ValueError: If the payload is not a mapping with ``'chunks'`` and
                ``'embeddings'`` entries of equal length.
            OSError: If the file cannot be opened.
        """
        # SECURITY: pickle.load() can execute arbitrary code from the file.
        # Only load cache files this program wrote itself; never point this
        # at a file obtained from an untrusted source.
        with open(path, 'rb') as f:
            data = pickle.load(f)

        try:
            chunks = list(data['chunks'])
            embeddings = [np.asarray(vector) for vector in data['embeddings']]
        except (KeyError, TypeError) as e:
            raise ValueError(f'Invalid cache file {path}: {e!r}') from e

        # A length mismatch would make zip(chunks, embeddings) silently drop
        # entries during retrieval, so reject it up front.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f'Invalid cache file {path}: {len(chunks)} chunks but '
                f'{len(embeddings)} embeddings'
            )

        self.chunks = chunks
        self.embeddings = embeddings
        print(f'Loaded {len(self.chunks)} entries from cache')

    def __len__(self):
        """Return the number of stored chunks."""
        return len(self.chunks)
|
||
|
|
|
||
|
|
|
||
|
|
def load_dataset(filepath: str) -> List[str]:
    """Read *filepath* and return its non-blank lines, whitespace-stripped.

    Args:
        filepath: Path of a UTF-8 text file with one entry per line.

    Returns:
        The stripped, non-empty lines in file order, or an empty list when
        the file does not exist (an error message is printed in that case).
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as file:
            chunks = []
            for raw_line in file:
                cleaned = raw_line.strip()
                # Lines that are empty or whitespace-only are skipped.
                if cleaned:
                    chunks.append(cleaned)
    except FileNotFoundError:
        print(f'Error: {filepath} not found')
        return []

    print(f'Loaded {len(chunks)} entries from {filepath}')
    return chunks
|
||
|
|
|
||
|
|
|
||
|
|
def build_database(chunks: List[str], use_cache: bool = True) -> VectorDatabase:
    """Return a vector database for *chunks*, reusing the cache when valid.

    The cache at ``CACHE_PATH`` is used only if it can be read and its stored
    chunks equal the (stripped) input chunks; otherwise every chunk is
    embedded with ``ollama.embed`` using ``EMBEDDING_MODEL``. Chunks whose
    embedding fails are reported and skipped.

    Args:
        chunks: Texts to embed.
        use_cache: When true, try to load from ``CACHE_PATH`` first and save
            a complete build back to it.

    Returns:
        The loaded or freshly built database. It may hold fewer entries than
        *chunks* if some embeddings failed.
    """
    # VectorDatabase.add() stores stripped text, so compare like with like.
    expected = [chunk.strip() for chunk in chunks]

    if use_cache and Path(CACHE_PATH).exists():
        cached = VectorDatabase()
        try:
            cached.load(CACHE_PATH)
        except Exception as e:
            # A corrupt or unreadable cache must not prevent startup.
            print(f'Ignoring unreadable cache {CACHE_PATH}: {e}')
        else:
            is_current = (
                list(cached.chunks) == expected
                and len(cached.embeddings) == len(cached.chunks)
            )
            if is_current:
                return cached
            # Stale cache: it was built from different data (or is partial).
            print('Cache does not match the dataset; rebuilding')

    # Always build into a fresh object so a failed load leaves no residue.
    db = VectorDatabase()
    total = len(chunks)
    print('Building vector database...')
    for i, chunk in enumerate(chunks):
        try:
            response = ollama.embed(model=EMBEDDING_MODEL, input=chunk)
            embedding = response['embeddings'][0]
            db.add(chunk, embedding)
        except Exception as e:
            print(f'\nError embedding chunk {i + 1}: {e}')
            continue
        else:
            print(f'Processed {i + 1}/{total}', end='\r')

    print(f'\nCompleted: {len(db)} chunks embedded')

    if use_cache:
        if len(db) == total:
            db.save(CACHE_PATH)
        else:
            # Caching a partial build would hide the failed chunks on every
            # later run, so leave the cache alone and retry next time.
            print(f'Not caching: {total - len(db)} chunk(s) failed to embed')

    return db
|
||
|
|
|
||
|
|
|
||
|
|
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of vectors *a* and *b*.

    Args:
        a: First vector.
        b: Second vector, same length as *a*.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when either vector has zero
        norm (the quotient is undefined there and would otherwise be NaN,
        which breaks sorting by similarity).
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return np.dot(a, b) / denominator
|
||
|
|
|
||
|
|
|
||
|
|
def retrieve(db: VectorDatabase, query: str, top_n: int = 3) -> List[Tuple[str, float]]:
    """Rank the database chunks by similarity to *query*.

    The query is embedded with ``ollama.embed`` using ``EMBEDDING_MODEL`` and
    compared to every stored embedding with ``cosine_similarity``.

    Args:
        db: Database whose ``chunks`` and ``embeddings`` are searched.
        query: Question text to embed.
        top_n: Length of the slice taken from the ranked list.

    Returns:
        ``(chunk, similarity)`` pairs, highest similarity first, or an empty
        list if the query embedding could not be produced (the error is
        printed).
    """
    try:
        response = ollama.embed(model=EMBEDDING_MODEL, input=query)
        query_vector = np.array(response['embeddings'][0])
    except Exception as e:
        print(f'Error generating query embedding: {e}')
        return []

    # Score every stored chunk against the query.
    scored = []
    for text, vector in zip(db.chunks, db.embeddings):
        scored.append((text, cosine_similarity(query_vector, vector)))

    # Best matches first, then keep the leading slice.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]
|
||
|
|
|
||
|
|
|
||
|
|
def generate_response(query: str, context: List[Tuple[str, float]]) -> None:
    """Print a streamed answer to *query* built from the retrieved context.

    Args:
        query: The user's question.
        context: ``(chunk, similarity)`` pairs to show the model. When empty,
            a notice is printed and the model is not called.

    The prompt is sent to ``ollama.chat`` with ``LANGUAGE_MODEL`` and the
    reply is written to stdout as it arrives. Any exception raised while
    chatting is caught and reported on stdout.
    """
    if not context:
        print('No relevant context found')
        return

    # One bullet per retrieved chunk, annotated with its similarity score.
    bullet_lines = []
    for text, score in context:
        bullet_lines.append(f'- [Relevance: {score:.2f}] {text}')
    context_block = '\n'.join(bullet_lines)

    prompt = (
        'You are a helpful assistant. Answer the question using ONLY the provided context.\n'
        "If the context doesn't contain relevant information, say so clearly.\n"
        '\n'
        'Context:\n'
        f'{context_block}\n'
        '\n'
        f'Question: {query}'
    )
    messages = [{'role': 'user', 'content': prompt}]

    try:
        reply_stream = ollama.chat(
            model=LANGUAGE_MODEL,
            messages=messages,
            stream=True,
        )

        print('\nChatbot response:')
        # Flush each piece so the answer appears incrementally.
        for part in reply_stream:
            print(part['message']['content'], end='', flush=True)
        print('\n')

    except Exception as e:
        print(f'Error generating response: {e}')
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Run the interactive pipeline: load data, build the database, answer.

    Loads the dataset from ``DATASET_PATH`` (returning immediately if it is
    empty or missing), builds or loads the vector database, then repeatedly
    reads a question from stdin, prints the retrieved chunks and streams the
    model's answer. The loop ends on ``quit``, ``exit`` or ``q``
    (case-insensitive), on end of input, or on Ctrl-C at the prompt.
    """
    # Load dataset
    dataset = load_dataset(DATASET_PATH)
    if not dataset:
        return

    # Build or load database
    db = build_database(dataset, use_cache=True)

    # Interactive query loop
    print('\n=== RAG System Ready ===')
    print('Type "quit" to exit\n')

    while True:
        try:
            query = input('Ask a question: ').strip()
        except (EOFError, KeyboardInterrupt):
            # Closed/exhausted stdin or Ctrl-C: leave cleanly rather than
            # dying with a traceback. The bare print ends the prompt line.
            print()
            break

        if query.lower() in ('quit', 'exit', 'q'):
            break

        if not query:
            continue

        # Retrieve relevant chunks
        results = retrieve(db, query, top_n=3)

        # Show retrieved context, truncating long chunks to 100 characters
        print('\nRetrieved context:')
        for chunk, sim in results:
            preview = f'{chunk[:100]}...' if len(chunk) > 100 else chunk
            print(f' [{sim:.3f}] {preview}')

        # Generate response
        generate_response(query, results)
|
||
|
|
|
||
|
|
|
||
|
|
# Start the interactive pipeline only when this file is run as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|