RAG-AI/SimpleRAG/RAG_Extension.py

183 lines
5.2 KiB
Python
Raw Normal View History

2025-11-24 15:26:15 -06:00
import ollama
import numpy as np
from pathlib import Path
import pickle
from typing import List, Tuple
# Configuration
# Model name passed to ollama.embed() for both dataset chunks and queries.
EMBEDDING_MODEL = 'hf.co/CompendiumLabs/bge-base-en-v1.5-gguf'
# Model name passed to ollama.chat() when generating the final answer.
LANGUAGE_MODEL = 'hf.co/bartowski/Llama-3.2-1B-Instruct-GGUF'
# Text file read at startup; each non-empty line becomes one chunk.
DATASET_PATH = 'cat-facts.txt'
# Pickle file where the chunks and their embeddings are cached between runs.
CACHE_PATH = 'vector_db.pkl'
class VectorDatabase:
    """In-memory store of text chunks and their embedding vectors.

    ``chunks[i]`` and ``embeddings[i]`` always describe the same entry.
    The store can be pickled to disk so embeddings need not be recomputed.
    """

    def __init__(self):
        self.chunks = []      # stripped chunk texts, in insertion order
        self.embeddings = []  # one numpy array per chunk, same order as chunks

    def add(self, chunk: str, embedding: List[float]):
        """Append one chunk (whitespace-stripped) together with its embedding."""
        self.chunks.append(chunk.strip())
        self.embeddings.append(np.array(embedding))

    def save(self, path: str):
        """Cache chunks and embeddings to ``path`` to avoid recomputing them."""
        with open(path, 'wb') as f:
            pickle.dump({'chunks': self.chunks, 'embeddings': self.embeddings}, f)
        print(f'Database cached to {path}')

    def load(self, path: str):
        """Replace the current contents with the cache stored at ``path``.

        The instance is only modified once the cache has been validated, so
        a malformed file never leaves it half-loaded.

        Raises:
            ValueError: if the file does not hold a dict with matching
                ``'chunks'`` and ``'embeddings'`` sequences.
            OSError: if the file cannot be opened.
        """
        with open(path, 'rb') as f:
            # SECURITY: unpickling can execute arbitrary code. Only load cache
            # files this program wrote itself; never a file from an untrusted
            # source.
            data = pickle.load(f)

        if not isinstance(data, dict) or 'chunks' not in data or 'embeddings' not in data:
            raise ValueError(f'{path} is not a valid vector database cache')

        chunks = list(data['chunks'])
        embeddings = [np.asarray(embedding) for embedding in data['embeddings']]
        # A length mismatch would silently drop entries when the two lists
        # are zipped together during retrieval.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f'{path} is corrupt: {len(chunks)} chunks but '
                f'{len(embeddings)} embeddings'
            )

        self.chunks = chunks
        self.embeddings = embeddings
        print(f'Loaded {len(self.chunks)} entries from cache')

    def __len__(self):
        """Return the number of stored chunks."""
        return len(self.chunks)
def load_dataset(filepath: str) -> List[str]:
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Args:
        filepath: Path of the text file, one entry per line.

    Returns:
        The stripped non-empty lines in file order, or an empty list when
        the file does not exist (a message is printed in that case).
    """
    entries: List[str] = []
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                text = raw_line.strip()
                # Whitespace-only lines carry no content; skip them.
                if text:
                    entries.append(text)
    except FileNotFoundError:
        print(f'Error: {filepath} not found')
        return []

    print(f'Loaded {len(entries)} entries from {filepath}')
    return entries
def build_database(chunks: List[str], use_cache: bool = True) -> VectorDatabase:
    """Return a vector database for ``chunks``, reusing the cache when valid.

    Args:
        chunks: Text chunks to embed. When empty, an existing cache is
            returned as-is.
        use_cache: When true, try to load ``CACHE_PATH`` first and save a
            freshly built database back to it.

    Returns:
        A ``VectorDatabase`` holding every chunk that could be embedded.
        Chunks whose embedding request fails are reported and skipped.
    """
    if use_cache and Path(CACHE_PATH).exists():
        cached = VectorDatabase()
        try:
            cached.load(CACHE_PATH)
        except Exception as e:
            # A corrupt or unreadable cache should not abort the run;
            # fall through and rebuild it.
            print(f'Ignoring unreadable cache {CACHE_PATH}: {e}')
        else:
            # The database stores stripped chunks, so compare against the
            # stripped dataset to detect a cache built from different data.
            expected = [chunk.strip() for chunk in chunks]
            if not chunks or list(cached.chunks) == expected:
                return cached
            print('Cache does not match dataset; rebuilding...')

    db = VectorDatabase()
    total = len(chunks)
    print('Building vector database...')
    for i, chunk in enumerate(chunks):
        try:
            response = ollama.embed(model=EMBEDDING_MODEL, input=chunk)
            embedding = response['embeddings'][0]
        except Exception as e:
            print(f'\nError embedding chunk {i + 1}: {e}')
            continue
        db.add(chunk, embedding)
        print(f'Processed {i + 1}/{total}', end='\r')
    print(f'\nCompleted: {len(db)} chunks embedded')

    if use_cache:
        # Caching an empty or partial build would make the missing chunks
        # permanent on later runs, so only complete builds are saved.
        if len(db) > 0 and len(db) == total:
            db.save(CACHE_PATH)
        else:
            print('Skipping cache: not every chunk was embedded')
    return db
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of vectors ``a`` and ``b``.

    Args:
        a: First vector.
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)`` as a plain ``float``, or ``0.0`` when
        either vector has zero norm.
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    # A zero vector has no direction; dividing would produce NaN, which
    # breaks the ordering when results are sorted by score.
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)
def retrieve(db: VectorDatabase, query: str, top_n: int = 3) -> List[Tuple[str, float]]:
    """Find the stored chunks most similar to ``query``.

    Args:
        db: Database whose chunks and embeddings are searched.
        query: Text to embed and compare against every stored chunk.
        top_n: Number of leading results to keep.

    Returns:
        ``(chunk, similarity)`` pairs ordered from most to least similar,
        or an empty list if the query could not be embedded.
    """
    try:
        response = ollama.embed(model=EMBEDDING_MODEL, input=query)
        query_vector = np.array(response['embeddings'][0])
    except Exception as e:
        print(f'Error generating query embedding: {e}')
        return []

    # Score every stored chunk against the query.
    scored = []
    for text, vector in zip(db.chunks, db.embeddings):
        scored.append((text, cosine_similarity(query_vector, vector)))

    # Highest similarity first, then keep the leading entries.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]
def generate_response(query: str, context: List[Tuple[str, float]]) -> None:
    """Stream a chat answer to ``query`` grounded in the retrieved context.

    Args:
        query: The user's question.
        context: ``(chunk, similarity)`` pairs to show the model.

    Prints the answer as it streams in. With no context, prints a notice
    and returns without calling the model. Errors from the chat call are
    printed rather than raised.
    """
    if not context:
        print('No relevant context found')
        return

    # One bullet per chunk, each tagged with its relevance score.
    bullet_lines = []
    for text, score in context:
        bullet_lines.append(f'- [Relevance: {score:.2f}] {text}')
    context_str = '\n'.join(bullet_lines)

    instruction = f'''You are a helpful assistant. Answer the question using ONLY the provided context.
If the context doesn't contain relevant information, say so clearly.
Context:
{context_str}
Question: {query}'''

    user_message = {'role': 'user', 'content': instruction}
    try:
        stream = ollama.chat(
            model=LANGUAGE_MODEL,
            messages=[user_message],
            stream=True,
        )
        print('\nChatbot response:')
        # Emit each streamed piece immediately so the answer appears live.
        for part in stream:
            print(part['message']['content'], end='', flush=True)
        print('\n')
    except Exception as e:
        print(f'Error generating response: {e}')
def main():
    """Run the interactive RAG pipeline.

    Loads the dataset, builds or loads the vector database, then answers
    questions typed by the user until they enter ``quit``, ``exit`` or
    ``q``, or end the input with Ctrl-D / Ctrl-C.
    """
    dataset = load_dataset(DATASET_PATH)
    if not dataset:
        return

    db = build_database(dataset, use_cache=True)
    if len(db) == 0:
        # Every embedding request failed; there is nothing to search.
        print('Error: no chunks could be embedded; nothing to search')
        return

    print('\n=== RAG System Ready ===')
    print('Type "quit" to exit\n')

    while True:
        try:
            query = input('Ask a question: ').strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt: exit without a traceback.
            print()
            break
        if query.lower() in ('quit', 'exit', 'q'):
            break
        if not query:
            continue

        results = retrieve(db, query, top_n=3)

        # Show what was retrieved, truncating long chunks to 100 characters.
        print('\nRetrieved context:')
        for chunk, sim in results:
            preview = f'{chunk[:100]}...' if len(chunk) > 100 else chunk
            print(f' [{sim:.3f}] {preview}')

        generate_response(query, results)


# Start the interactive loop only when run as a script, not on import.
if __name__ == '__main__':
    main()