"""Minimal retrieval-augmented generation (RAG) chatbot.

Each non-empty line of ``DATASET_PATH`` is embedded with an Ollama embedding
model, the vectors are cached on disk with pickle, and questions typed at the
prompt are answered by an Ollama chat model using the most similar lines as
context.
"""

import pickle
from pathlib import Path
from typing import List, Tuple

import numpy as np
import ollama

# Configuration
EMBEDDING_MODEL = 'hf.co/CompendiumLabs/bge-base-en-v1.5-gguf'
LANGUAGE_MODEL = 'hf.co/bartowski/Llama-3.2-1B-Instruct-GGUF'
DATASET_PATH = 'cat-facts.txt'
CACHE_PATH = 'vector_db.pkl'


class VectorDatabase:
    """Simple in-memory vector store with pickle-based caching.

    ``chunks[i]`` is the text whose embedding is ``embeddings[i]``.
    """

    def __init__(self):
        self.chunks = []
        self.embeddings = []

    def add(self, chunk: str, embedding: List[float]):
        """Append a text chunk (whitespace-stripped) and its embedding."""
        self.chunks.append(chunk.strip())
        self.embeddings.append(np.array(embedding))

    def save(self, path: str):
        """Pickle the chunks and embeddings to *path* to avoid recomputing."""
        with open(path, 'wb') as f:
            pickle.dump({'chunks': self.chunks, 'embeddings': self.embeddings}, f)
        print(f'Database cached to {path}')

    def load(self, path: str):
        """Replace the database contents with a cache written by ``save``.

        Raises:
            ValueError: if the cache holds a different number of chunks and
                embeddings.
            Exception: whatever ``open`` / ``pickle.load`` raise for a
                missing, unreadable, or malformed file.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load cache
        # files this program wrote itself; never a file from an untrusted
        # source.
        with open(path, 'rb') as f:
            data = pickle.load(f)

        # Read and validate both fields before touching self, so a malformed
        # cache cannot leave the database half-updated.
        chunks = data['chunks']
        embeddings = data['embeddings']
        if len(chunks) != len(embeddings):
            raise ValueError(
                f'cache has {len(chunks)} chunks but {len(embeddings)} embeddings'
            )

        self.chunks = chunks
        self.embeddings = embeddings
        print(f'Loaded {len(self.chunks)} entries from cache')

    def __len__(self):
        return len(self.chunks)


def load_dataset(filepath: str) -> List[str]:
    """Read *filepath* and return its non-empty lines, stripped.

    Returns an empty list (after printing an error) if the file is missing.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as file:
            # Filter empty lines and strip whitespace
            chunks = [line.strip() for line in file if line.strip()]
    except FileNotFoundError:
        print(f'Error: {filepath} not found')
        return []

    print(f'Loaded {len(chunks)} entries from {filepath}')
    return chunks


def _embed_text(text: str) -> List[float]:
    """Return the embedding of *text* from ``EMBEDDING_MODEL`` via Ollama."""
    response = ollama.embed(model=EMBEDDING_MODEL, input=text)
    # A single input is sent, so the first entry is its vector.
    return response['embeddings'][0]


def build_database(chunks: List[str], use_cache: bool = True) -> VectorDatabase:
    """Build the vector database, or load it from ``CACHE_PATH``.

    Args:
        chunks: Text chunks to embed when no usable cache exists.
        use_cache: If true, load an existing cache instead of embedding, and
            write a cache after a complete build.

    Returns:
        The populated database. Chunks whose embedding failed are skipped.
    """
    db = VectorDatabase()

    # Try loading from cache.
    # NOTE: an existing cache is trusted as-is; it is not compared against
    # `chunks` or the embedding model. Delete CACHE_PATH after changing either.
    if use_cache and Path(CACHE_PATH).exists():
        try:
            db.load(CACHE_PATH)
        except Exception as e:
            # A corrupt or unreadable cache should not be fatal: rebuild.
            print(f'Error loading cache {CACHE_PATH}: {e}')
            db = VectorDatabase()
        else:
            return db

    # Build new database
    print('Building vector database...')
    total = len(chunks)
    failed = 0
    for position, chunk in enumerate(chunks, start=1):
        try:
            db.add(chunk, _embed_text(chunk))
        except Exception as e:
            # Best effort: report the failure and keep going.
            failed += 1
            print(f'\nError embedding chunk {position}: {e}')
        else:
            print(f'Processed {position}/{total}', end='\r')

    print(f'\nCompleted: {len(db)} chunks embedded')

    # Cache for future use. An empty or partial database is not cached:
    # because an existing cache short-circuits the build above, it would be
    # served on every later run and the missing chunks never retried.
    if use_cache:
        if failed or len(db) == 0:
            print('Cache not written: database is empty or incomplete')
        else:
            db.save(CACHE_PATH)

    return db


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of *a* and *b*.

    Returns 0.0 when either vector has zero norm, where the similarity is
    undefined, instead of dividing by zero and producing NaN.
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)


def retrieve(db: VectorDatabase, query: str, top_n: int = 3) -> List[Tuple[str, float]]:
    """Retrieve the chunks most similar to *query*.

    Args:
        db: Database to search.
        query: Question text to embed and compare against every chunk.
        top_n: Maximum number of results to return.

    Returns:
        Up to *top_n* ``(chunk, similarity)`` pairs, most similar first.
        Empty if *top_n* is not positive, the database is empty, or the query
        could not be embedded.
    """
    # A negative top_n would otherwise slice from the end ([:-k]) and return
    # almost everything; an empty database needs no embedding call at all.
    if top_n <= 0 or len(db) == 0:
        return []

    try:
        query_embedding = np.array(_embed_text(query))
    except Exception as e:
        print(f'Error generating query embedding: {e}')
        return []

    # Calculate similarities
    scored = [
        (chunk, cosine_similarity(query_embedding, embedding))
        for chunk, embedding in zip(db.chunks, db.embeddings)
    ]

    # Sort and return top N
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


def generate_response(query: str, context: List[Tuple[str, float]]) -> None:
    """Stream a chat-model answer to *query* grounded in *context*.

    Args:
        query: The user's question.
        context: ``(chunk, similarity)`` pairs as returned by ``retrieve``.

    Prints the answer to stdout as it is generated. Prints a notice and
    returns without calling the model when *context* is empty.
    """
    if not context:
        print('No relevant context found')
        return

    # Build instruction with relevance scores
    context_str = '\n'.join(
        f'- [Relevance: {sim:.2f}] {chunk}' for chunk, sim in context
    )

    instruction = f'''You are a helpful assistant. Answer the question using ONLY the provided context.
If the context doesn't contain relevant information, say so clearly.

Context:
{context_str}

Question: {query}'''

    try:
        stream = ollama.chat(
            model=LANGUAGE_MODEL,
            messages=[{'role': 'user', 'content': instruction}],
            stream=True,
        )

        print('\nChatbot response:')
        for part in stream:
            # flush so each streamed piece appears immediately
            print(part['message']['content'], end='', flush=True)
        print('\n')
    except Exception as e:
        print(f'Error generating response: {e}')


def main():
    """Run the interactive RAG loop: load data, build the index, answer questions."""
    # Load dataset
    dataset = load_dataset(DATASET_PATH)
    if not dataset:
        return

    # Build or load database
    db = build_database(dataset, use_cache=True)

    # Interactive query loop
    print('\n=== RAG System Ready ===')
    print('Type "quit" to exit\n')

    while True:
        try:
            query = input('Ask a question: ').strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave quietly, no traceback.
            print()
            break

        if query.lower() in ['quit', 'exit', 'q']:
            break
        if not query:
            continue

        # Retrieve relevant chunks
        results = retrieve(db, query, top_n=3)

        # Show retrieved context, truncating long chunks to 100 characters
        print('\nRetrieved context:')
        for chunk, sim in results:
            preview = f'{chunk[:100]}...' if len(chunk) > 100 else chunk
            print(f' [{sim:.3f}] {preview}')

        # Generate response
        generate_response(query, results)


if __name__ == '__main__':
    main()