We will dissect this code in a step-by-step manner to ensure it is easily understandable for beginners. The code is designed to transform natural language queries into SQL through a technique known as Retrieval-Augmented Generation (RAG). The process involves retrieving pertinent information (such as database tables) according to the user's input, which is then fed into an AI model for SQL query generation. Below is a comprehensive breakdown of each component:
1. Importing Libraries
import openai
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Tuple, Set
import logging
from dataclasses import dataclass
from enum import Enum
openai: This library lets us use OpenAI's language models (e.g., GPT-3.5) to generate SQL queries from user input.
sentence_transformers: Helps convert sentences into embeddings (numerical vectors) that capture the meaning of text, making it easier to compare.
faiss: A library for efficient similarity search over dense vectors, used here as a lightweight vector index to find the most similar items quickly.
numpy: A library for handling arrays and mathematical operations.
typing: Allows us to specify data types for function inputs and outputs.
logging: Enables us to print messages in a structured way, useful for debugging.
dataclasses and enum: Help define data structures and constant values in a cleaner, more organized way.
2. Enable Logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
logging.basicConfig: Sets up how log messages should look and their severity level.
logger: Lets us easily print structured messages, which is helpful to see what’s happening while the code runs.
3. Defining TableType and TableSchema Classes
These classes help us structure our information about database tables.
class TableType(Enum):
TRANSACTION = "transaction"
DIMENSION = "dimension"
LOOKUP = "lookup"
TableType: Defines different types of tables. Enum is used to represent these types as fixed categories (like TRANSACTION, DIMENSION, etc.).
@dataclass(frozen=True, eq=True)
class TableSchema:
name: str
description: str
schema: str
table_type: TableType
TableSchema: Stores information about each database table, including the table’s name, description, schema, and type (from TableType). The @dataclass decorator simplifies the code and makes it immutable (frozen=True) so its values cannot be changed.
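Because the dataclass is frozen and comparable, TableSchema instances are also hashable, which is what later allows them to be stored in a Python set. A minimal illustrative sketch (this sales instance is just an example, not part of the original listing):
example_table = TableSchema(
    name="sales",
    description="Sales table contains records of sales transactions",
    schema="sales(sale_id INT, amount DECIMAL(10,2))",
    table_type=TableType.TRANSACTION
)
# example_table.name = "orders"  # would raise dataclasses.FrozenInstanceError
print({example_table})  # hashable, so it can live in a set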
4. The RAGSQLGenerator Class
This is the main class that performs the conversion from natural language to SQL.
Initialization (__init__ Method)
class RAGSQLGenerator:
def __init__(self, openai_api_key: str):
self.openai_api_key = openai_api_key
openai.api_key = openai_api_key
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
logger.info("Successfully initialized SentenceTransformer model")
# Initialize context data and FAISS index
self._initialize_context_data()
self._initialize_faiss_index()
self.openai_api_key: Stores the OpenAI API key, allowing us to access OpenAI’s language models.
self.embedding_model: Loads a SentenceTransformer model that can create vector representations of sentences. These vectors make it easier to compare the meaning of different sentences.
Logging: Logs that the model is successfully loaded.
_initialize_context_data and _initialize_faiss_index: Calls two helper methods that set up the table information and the FAISS index (explained next).
Initializing Context Data (_initialize_context_data)
def _initialize_context_data(self) -> None:
self.context_data = [
TableSchema(
name="sales",
description="Sales table contains records of sales transactions with timestamp and amounts",
schema="sales(sale_id INT, product_id INT, sale_date DATETIME, amount DECIMAL(10,2))",
table_type=TableType.TRANSACTION
),
TableSchema(
name="products",
description="Products table contains product information and their categories",
schema="products(product_id INT, category_id INT, product_name VARCHAR(255), price DECIMAL(10,2))",
table_type=TableType.DIMENSION
),
TableSchema(
name="categories",
description="Categories table contains category hierarchies and descriptions",
schema="categories(category_id INT, category_name VARCHAR(255), parent_category_id INT)",
table_type=TableType.LOOKUP
)
]
self.context_data: Defines information about the database tables as a list of TableSchema objects. Each entry describes a table (name, description, schema, and type).
Initializing FAISS Index (_initialize_faiss_index)
def _initialize_faiss_index(self) -> None:
context_embeddings = np.array([
self.embedding_model.encode(f"{item.description} {item.schema}")
for item in self.context_data
]).astype('float32')
self.index_dimension = context_embeddings.shape[1]
self.index = faiss.IndexFlatL2(self.index_dimension)
self.index.add(context_embeddings)
logger.info("Successfully initialized FAISS index")
Embedding Creation: Converts each table’s description and schema to an embedding.
FAISS Index Setup: Creates a FAISS index, adds embeddings to it, and makes it ready for similarity searches.
5. Expanding Context with Related Tables
def _expand_context_with_related_tables(self, initial_contexts: List[TableSchema]) -> List[TableSchema]:
expanded_contexts: Set[TableSchema] = set(initial_contexts)
relationships = {
"sales": ["products"],
"products": ["categories"],
"categories": ["products"]
}
for context in list(expanded_contexts):
if context.name in relationships:
for related_table in relationships[context.name]:
for table_schema in self.context_data:
if table_schema.name == related_table:
expanded_contexts.add(table_schema)
return list(expanded_contexts)
expanded_contexts: Keeps track of tables related to the main tables identified.
Relationships: Defines which tables are related (e.g., sales is related to products).
Looping: Adds related tables to expanded_contexts and returns the final list.
6. Retrieving Relevant Context (retrieve_context)
def retrieve_context(self, user_input: str, top_k: int = 2) -> List[TableSchema]:
user_embedding = self.embedding_model.encode(user_input).reshape(1, -1).astype('float32')
distances, indices = self.index.search(user_embedding, top_k)
relevant_contexts = [self.context_data[i] for i in indices[0]]
return self._expand_context_with_related_tables(relevant_contexts)
user_embedding: Converts the user input to an embedding for comparison.
FAISS Search: Finds the top k (default 2) most relevant tables based on the input.
Related Tables: Expands context with related tables and returns the final list.
7. Generating SQL Query (generate_sql_query)
def generate_sql_query(self, user_input: str, schema_info: List[TableSchema]) -> str:
schema_details = "\n".join([
f"Table {schema.name} ({schema.table_type.value}):\n"
f"Description: {schema.description}\n"
f"Schema: {schema.schema}\n"
for schema in schema_info
])
prompt = f"""
Given the following database schema:
{schema_details}
Generate a SQL query for this request: {user_input}
Requirements:
- Use appropriate JOIN syntax, WHERE clauses, and formatting.
- Include GROUP BY and HAVING clauses if necessary.
- Use table aliases and descriptive column names.
- Include comments for complex logic.
SQL Query:
"""
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a SQL expert that generates efficient queries."},
{"role": "user", "content": prompt}
],
temperature=0,
max_tokens=500
)
sql_query = response.choices[0].message.content.strip()
logger.info("Successfully generated SQL query")
return sql_query
schema_details: Creates a summary of each table to help OpenAI’s model understand the database.
Prompt Creation: Prepares a detailed instruction prompt for the AI model.
OpenAI API Call: Sends the prompt to OpenAI’s model to generate an SQL query based on the schema and user request.
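Note that this call uses the legacy openai.ChatCompletion interface (openai versions before 1.0). As a hedged sketch, with openai>=1.0 the equivalent call would look roughly like this:
from openai import OpenAI

client = OpenAI(api_key=self.openai_api_key)
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a SQL expert that generates efficient queries."},
        {"role": "user", "content": prompt}
    ],
    temperature=0,
    max_tokens=500
)
sql_query = response.choices[0].message.content.strip()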
8. Main Conversion Flow (natural_language_to_sql)
def natural_language_to_sql(self, user_input: str) -> Tuple[str, List[TableSchema]]:
relevant_contexts = self.retrieve_context(user_input)
sql_query = self.generate_sql_query(user_input, relevant_contexts)
return sql_query, relevant_contexts
retrieve_context: Finds relevant tables based on the user’s input.
generate_sql_query: Uses the relevant tables to generate the final SQL query.
9. Main Function (main)
def main():
generator = RAGSQLGenerator("YOUR_OPENAI_API_KEY") # Replace with actual API key
test_queries = [
"Find the total sales for each product category in the last month",
"What are the top 5 selling products by revenue?",
"Show me the monthly sales trend for each category"
]
for query in test_queries:
print(f"\nProcessing query: {query}")
sql_query, contexts = generator.natural_language_to_sql(query)
print("\nRelevant Tables:")
for ctx in contexts:
print(f"- {ctx.name}: {ctx.description}")
print("\nGenerated SQL Query:")
print(sql_query)
print("\n" + "=" * 50)
if __name__ == "__main__":
main()
Main Function: Creates an instance of RAGSQLGenerator, tests various queries, and prints the relevant tables and SQL queries for each.
This code handles natural-language-to-SQL conversion by retrieving the relevant schema and generating SQL with OpenAI's language model.
An example of the generated output is:
Generated SQL Query:
```sql
SELECT
c.category_name AS category,
YEAR(s.sale_date) AS year,
MONTH(s.sale_date) AS month,
SUM(s.amount) AS total_sales
FROM sales s
JOIN products p ON s.product_id = p.product_id
JOIN categories c ON p.category_id = c.category_id
GROUP BY c.category_name, YEAR(s.sale_date), MONTH(s.sale_date)
ORDER BY c.category_name, YEAR(s.sale_date), MONTH(s.sale_date);
```
Use of User, System and Dynamic Prompt
In the provided code, the system, user, and dynamic prompts work together to instruct OpenAI’s model on how to generate SQL queries based on the context and user input.
Here’s how each of these components is structured in the code:
1. System Prompt
The system prompt sets the AI's role and gives it general instructions on what it should do. It helps the model understand that it needs to act as a SQL expert.
In the code, this is found in the generate_sql_query method:
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a SQL expert that generates efficient queries."},
{"role": "user", "content": prompt}
],
temperature=0,
max_tokens=500
)
The message with role "system" ("You are a SQL expert that generates efficient queries.") is the system prompt. It sets up the model to understand its role as a "SQL expert," providing clear context for the type of responses it should generate.
2. Dynamic Prompt
The dynamic prompt includes context about the relevant database schema and requirements for SQL generation. It is created based on the relevant tables retrieved by the FAISS index. This context can change depending on which tables are retrieved for each query.
In the code, the dynamic prompt is constructed in the generate_sql_query method:
schema_details = "\n".join([
f"Table {schema.name} ({schema.table_type.value}):\n"
f"Description: {schema.description}\n"
f"Schema: {schema.schema}\n"
for schema in schema_info
])
prompt = f"""
Given the following database schema:
{schema_details}
Generate a SQL query for this request: {user_input}
Requirements:
- Use appropriate JOIN syntax, WHERE clauses, and formatting.
- Include GROUP BY and HAVING clauses if necessary.
- Use table aliases and descriptive column names.
- Include comments for complex logic.
SQL Query:
"""
Here:
schema_details dynamically lists the table schemas relevant to the user’s query, giving the model specific details it needs to form an accurate SQL query.
The requirements (like Use appropriate JOIN syntax...) guide the model on formatting and structuring the SQL, ensuring consistency across generated queries.
The prompt variable combines all of this into one detailed, structured prompt that includes the database schema and specific requirements.
3. User Prompt
The user prompt is the natural language query provided by the user. It’s the main question or request that the user wants to convert into SQL.
In the code, the user_input variable represents the user prompt, which is fed into the generate_sql_query method as part of the dynamic prompt. Here’s where the user prompt is integrated:
Generate a SQL query for this request: {user_input}
How They Work Together
System Prompt: Sets up the model as a SQL expert.
Dynamic Prompt: Provides the necessary context (e.g., database schema) and requirements for SQL formatting.
User Prompt: Provides the specific user query (natural language request) that needs to be converted into SQL.
Together, these prompts form a structured prompt for the model to process, combining expertise, context, and the user’s question to generate an accurate SQL query.
Best Practices in Managing and Governing Data
By incorporating best practices for data quality and data management, the results of the code can be greatly improved through increased accuracy in retrieval, relevance of queries, and overall robustness. Below are examples of how to seamlessly integrate these practices.
1. Metadata Management for Table Schemas and Definitions
High-quality, standardized definitions and descriptions for schemas, tables, and columns help the retrieval process.
Practice: Ensure accurate, complete, and standardized descriptions for each table and column.
Best practice: Use ISO/IEC 11179, an international standard that provides recommendations and requirements for defining data and metadata.
Example: Use consistent phrasing for similar table types. Instead of varying terms for sales data (e.g., "Sales transactions" vs. "Sales data"), standardize descriptions to reduce ambiguity and improve embedding similarity accuracy.
Schema Quality Check:
Verify that the table name, description, and schema align with database documentation.
Regularly review and update descriptions to reflect any changes in the database.
Define a function to validate the schema information before adding it to the context:
def validate_schema(schema: TableSchema) -> bool:
# Example validation: Check if description and schema fields are filled and meet length requirements
if not schema.description or len(schema.description) < 10:
logger.warning(f"Table {schema.name} has an incomplete description.")
return False
if not schema.schema or len(schema.schema) < 10:
logger.warning(f"Table {schema.name} has an incomplete schema.")
return False
return True
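As a sketch of how this check might be wired in, the list built in _initialize_context_data could be filtered through validate_schema before it is stored (candidate_tables below stands for the full list of TableSchema objects defined earlier):
candidate_tables = [
    # ... the TableSchema entries for sales, products, and categories ...
]
self.context_data = [table for table in candidate_tables if validate_schema(table)]
logger.info(f"{len(self.context_data)} of {len(candidate_tables)} tables passed validation")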
2. Consistency, Data Currency and Version Control
Handling format changes and schema drift consistently ensures that the data reflects the latest schema definitions and stays compatible across different environments.
Practice: Implement version control for table schema definitions, so that model results align with the correct database version.
Example: Store versions of the database schema as JSON files or database documentation. Track any schema updates and ensure these changes are reflected in the embeddings.
Code Enhancement
Use a JSON configuration file to load schema information with version tracking:
import json

def load_schema_from_file(version: str) -> List[TableSchema]:
    # Load schema details from a versioned JSON file
    with open(f"schema_v{version}.json", "r") as file:
        schema_data = json.load(file)
    # Convert the table_type string from JSON into the TableType enum member
    return [
        TableSchema(**{**item, "table_type": TableType(item["table_type"])})
        for item in schema_data
    ]
Now, you can dynamically load schemas based on the required version:
self.context_data = load_schema_from_file("1.0") # Use versioned schema
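For reference, a minimal example of what a versioned schema_v1.0.json file might contain, mirroring the sales table defined earlier (the exact file layout is an assumption, not prescribed by the code):
[
  {
    "name": "sales",
    "description": "Sales table contains records of sales transactions with timestamp and amounts",
    "schema": "sales(sale_id INT, product_id INT, sale_date DATETIME, amount DECIMAL(10,2))",
    "table_type": "transaction"
  }
]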
3. Metadata Enrichment for Table Relationships
Enhancing metadata with detailed relationships between tables (e.g., foreign keys, primary keys) improves context expansion and query relevance.
Practice: Define and store relationships between tables (e.g., sales has a foreign key to products) in the metadata.
Example: Extend each TableSchema with related tables or foreign keys. For instance, adding "related_tables": ["products", "categories"] to sales improves the RAG model’s ability to select appropriate joins.
Code Enhancement
Add a related_tables field to TableSchema to expand the context dynamically:
@dataclass(frozen=True, eq=True)
class TableSchema:
    name: str
    description: str
    schema: str
    table_type: TableType
    # New field storing names of related tables; a tuple (rather than a list)
    # keeps the frozen dataclass hashable, so instances can still go into sets
    related_tables: Tuple[str, ...] = ()
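With this field in place, the hard-coded relationships dictionary in _expand_context_with_related_tables could be replaced by metadata lookups. A hedged sketch, assuming related_tables is populated when the schemas are defined:
def _expand_context_with_related_tables(self, initial_contexts: List[TableSchema]) -> List[TableSchema]:
    expanded_contexts: Set[TableSchema] = set(initial_contexts)
    for context in list(expanded_contexts):
        # Look up related tables from metadata instead of a hard-coded mapping
        for related_name in context.related_tables:
            for table_schema in self.context_data:
                if table_schema.name == related_name:
                    expanded_contexts.add(table_schema)
    return list(expanded_contexts)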
4. Regular Embedding Maintenance and Update Process
Regularly updating embeddings to reflect schema changes and re-training the embedding model can prevent outdated or inaccurate retrievals.
Practice: Schedule regular re-embedding and index rebuilding to ensure FAISS reflects the latest table schema definitions.
Example: If a new column or table is added, regenerate the relevant embeddings and re-index in FAISS.
Code Enhancement
Implement a re-embedding method on RAGSQLGenerator to update FAISS after schema changes:
def update_embeddings(self) -> None:
    # Re-generate embeddings for updated context_data and rebuild the FAISS index
    context_embeddings = np.array([
        self.embedding_model.encode(f"{item.description} {item.schema}")
        for item in self.context_data
    ]).astype('float32')
    self.index = faiss.IndexFlatL2(context_embeddings.shape[1])
    self.index.add(context_embeddings)
    logger.info("FAISS index updated with new embeddings")
5. Quality Control on User Input Processing
Managing the quality of user inputs helps ensure that queries are interpreted accurately.
Practice: Use natural language preprocessing to clean and standardize user inputs.
Example: Preprocess inputs by removing unnecessary characters, standardizing terms (e.g., synonyms), and validating for completeness.
Code Enhancement
Define a preprocessing function for user inputs:
import re
def preprocess_user_input(user_input: str) -> str:
"""Clean and standardize user input to improve query consistency and quality."""
# 1. Remove unnecessary punctuation (e.g., periods, commas, special characters)
user_input = re.sub(r"[^\w\s]", "", user_input)
# 2. Convert to lowercase to maintain consistency
user_input = user_input.lower()
# 3. Remove extra whitespace
user_input = re.sub(r"\s+", " ", user_input).strip()
# 4. Standardize common phrases and terms (e.g., "last month" -> "previous month")
replacements = {
"last month": "previous month",
"previous year": "last year",
"year to date": "ytd",
"top five": "top 5",
"most recent": "latest"
}
for phrase, replacement in replacements.items():
user_input = re.sub(rf"\b{phrase}\b", replacement, user_input)
# 5. Handle numbers in words to numerals (e.g., "ten" -> "10")
word_to_number = {
"zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
"five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9",
"ten": "10"
}
for word, number in word_to_number.items():
user_input = re.sub(rf"\b{word}\b", number, user_input)
# 6. Normalize date expressions (e.g., "today", "yesterday", "next month")
date_replacements = {
"today": "current day",
"yesterday": "previous day",
"tomorrow": "next day",
"next month": "upcoming month"
}
for phrase, replacement in date_replacements.items():
user_input = re.sub(rf"\b{phrase}\b", replacement, user_input)
# 7. Standardize common SQL keywords/phrases for consistency
sql_keywords = {
"show me": "select",
"list": "select",
"give me": "select",
"find": "select"
}
for phrase, replacement in sql_keywords.items():
user_input = re.sub(rf"\b{phrase}\b", replacement, user_input)
# 8. Replace abbreviations with full forms, if applicable
abbreviations = {
"dept": "department",
"emp": "employee",
"mgr": "manager",
"cust": "customer",
"amt": "amount"
}
for abbreviation, full_form in abbreviations.items():
user_input = re.sub(rf"\b{abbreviation}\b", full_form, user_input)
return user_input
Call preprocess_user_input before passing user_input to retrieve_context:
user_input = preprocess_user_input(user_input)
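One way to wire this in, as a sketch, is to call the preprocessor at the top of natural_language_to_sql so every downstream step sees the cleaned text:
def natural_language_to_sql(self, user_input: str) -> Tuple[str, List[TableSchema]]:
    # Clean and standardize the raw input before retrieval and generation
    cleaned_input = preprocess_user_input(user_input)
    relevant_contexts = self.retrieve_context(cleaned_input)
    sql_query = self.generate_sql_query(cleaned_input, relevant_contexts)
    return sql_query, relevant_contexts
For example, "Show me the top five selling products from last month" would be rewritten to "select the top 5 selling products from previous month" before retrieval.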
6. Performance Monitoring and Logging of SQL Generation
Regular monitoring helps identify bottlenecks and improve performance over time.
Practice: Track response times and accuracy metrics for retrieval and generation steps.
Example: Measure the time taken by FAISS retrieval and OpenAI API calls, logging any unusually long responses.
Code Enhancement
import time

def retrieve_context(self, user_input: str, top_k: int = 2) -> List[TableSchema]:
    start_time = time.time()
    user_embedding = self.embedding_model.encode(user_input).reshape(1, -1).astype('float32')
    distances, indices = self.index.search(user_embedding, top_k)
    relevant_contexts = [self.context_data[i] for i in indices[0]]
    expanded_contexts = self._expand_context_with_related_tables(relevant_contexts)
    logger.info(f"Context retrieval time: {time.time() - start_time:.3f} seconds")
    return expanded_contexts
This can help identify and optimize parts of the workflow that may be slowing down.
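The same pattern can be applied to the generation step. A sketch, assuming the OpenAI call inside generate_sql_query is wrapped the same way:
start_time = time.time()
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a SQL expert that generates efficient queries."},
        {"role": "user", "content": prompt}
    ],
    temperature=0,
    max_tokens=500
)
logger.info(f"SQL generation time: {time.time() - start_time:.3f} seconds")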
7. Implementing Access Control and Security Measures
Protecting sensitive data and managing access control prevents unauthorized access and usage of the data.
Practice: Store sensitive information (e.g., API keys) in secure, encrypted storage (e.g., environment variables).
Example: Load API keys from environment variables instead of hardcoding them in the code.
Code Enhancement
Access the OpenAI API key securely from an environment variable:
import os
openai.api_key = os.getenv("OPENAI_API_KEY")
Set the environment variable in your terminal or IDE to avoid exposing the key in your code.
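A minimal sketch of the full pattern, assuming the variable is exported as OPENAI_API_KEY (for example with export OPENAI_API_KEY="sk-..." in the shell):
import os

api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise EnvironmentError("OPENAI_API_KEY environment variable is not set")
generator = RAGSQLGenerator(api_key)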
8. Documentation and Training Dataset Transparency
Detailed documentation of schema fields and query examples helps users understand the data structure and usage, leading to better input queries.
Practice: Document each schema field and store common query examples to train the model on realistic, quality examples.
Example: Store a JSON or markdown file with schema explanations and sample queries to help guide input expectations.
Code Enhancement
Create a schema_docs.json file with sample queries and load it for reference in the code or as a help guide:
{
"sales": {
"description": "Sales table contains records of sales transactions",
"sample_queries": [
"Find total sales by month",
"List top 5 products by sales volume"
]
},
...
}
Load the documentation in code to serve as a reference:
def load_sample_queries(schema_name: str) -> List[str]:
with open("schema_docs.json", "r") as f:
docs = json.load(f)
return docs.get(schema_name, {}).get("sample_queries", [])
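These sample queries could also be surfaced to the model as lightweight few-shot hints. A hedged sketch that appends them to the dynamic prompt for each retrieved table (an extension, not part of the original generate_sql_query):
sample_queries = []
for schema in schema_info:
    sample_queries.extend(load_sample_queries(schema.name))
if sample_queries:
    prompt += "\nExample requests for these tables:\n" + "\n".join(f"- {q}" for q in sample_queries)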
Summary of Best Practices Applied
Applying these data quality and data management practices can improve the effectiveness of the retrieval and generation process in RAG by ensuring:
Metadata Management for accurate, standardized table schemas and descriptions.
Consistency, Data Currency and Version Control to keep schemas up to date.
Metadata Enrichment for Table Relationships to guide join selection and context expansion.
Regular Embedding Maintenance with timely re-indexing to reflect schema changes and drift.
Quality Control on User Input Processing for cleaner user queries.
Performance Monitoring and Logging of SQL Generation to maintain efficiency.
Implementing Access Control and Security Measures by managing access to sensitive data.
Documentation and Training Dataset Transparency via documentation to guide user expectations.
These practical steps align data quality and management practices with the technical flow, resulting in a more reliable, accurate, and user-friendly RAG system.