跳转到内容

数据库与 SQL 交互

在 LLM 开发的实际工作流中,你的原始数据往往不是以文件形式躺在磁盘上的,而是存在数据库里——产品数据库里存着用户对话日志,标注平台的数据库里存着人工审核结果,评估系统的数据库里存着模型打分记录。Pandas 的 read_sql()to_sql() 就是连接"数据库世界"和"Pandas 数据处理世界"的桥梁。

这一节我们不罗列所有参数,而是从实际的 LLM 开发场景出发,讲清楚:什么时候该用数据库 I/O、怎么高效地读取大量数据、怎么把清洗后的数据写回去,以及生产环境中的最佳实践。

什么时候需要 Pandas 连数据库

先想清楚一个问题:既然数据库本身就能做聚合、过滤、排序,为什么还要把数据读到 Pandas 里来处理?答案在于分工不同。数据库擅长的是基于索引的精确查询和预聚合统计,而 Pandas 擅长的是复杂的行级数据变换、文本处理、多步骤的 ETL 流水线。

在 LLM 开发中,典型的数据流是这样的:

产品 PostgreSQL ──read_sql──→ Pandas 清洗/变换 ──to_sql/导出──→ 训练文件(JSONL)
        ↑                                                            │
        │  用户对话日志                                              ↓
        │  标注结果回写                                    模型训练/SFT
        │  质量分数更新

具体来说,你会在以下场景频繁使用 Pandas 的数据库功能:

  • 从产品数据库拉取原始语料:用 SQL 做初步筛选(时间范围、用户群体),然后把结果拉到 Pandas 做深度清洗
  • 合并多源数据:从不同的表 JOIN 出来之后,在 Pandas 里做特征工程
  • 回写处理结果:清洗后的质量评分、去重标记等写回数据库供下游系统消费
  • 本地缓存中间结果:用 SQLite 做轻量级的本地持久化缓存

read_sql():从数据库获取数据

基本用法与两种连接方式

Pandas 的数据库操作支持两种连接方式:一种是 Python 标准库的 sqlite3 直连,另一种是 SQLAlchemy 的 engine 对象。两者的选择取决于你的场景:

python
import pandas as pd
import sqlite3

conn = sqlite3.connect(':memory:')

conn.execute("""
CREATE TABLE conversations (
    id INTEGER PRIMARY KEY,
    user_id TEXT NOT NULL,
    prompt TEXT,
    response TEXT,
    quality_score REAL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")

conn.executemany("""
INSERT INTO conversations (user_id, prompt, response, quality_score)
VALUES ('u_001', '什么是AI', 'AI是...', 4.5),
       ('u_002', '解释Python', 'Python是...', 3.8),
       ('u_003', '如何学习LLM', '建议从...', 4.2)
""")

df = pd.read_sql_query(
    "SELECT * FROM conversations WHERE quality_score >= 4.0",
    conn
)
print(df)

输出:

   id user_id       prompt   response  quality_score           created_at
0   1   u_001      什么是AI     AI是...            4.5  2025-01-15 10:30:00
2   3   u_003  如何学习LLM   建议从...            4.2  2025-01-15 10:32:00

对于 SQLite 这种嵌入式数据库,直接用 sqlite3.connect() 就够了。但如果你连的是 PostgreSQL 或 MySQL,强烈建议使用 SQLAlchemy 的 engine

python
from sqlalchemy import create_engine, text

engine = create_engine('postgresql://user:password@localhost:5432/llm_db')

df = pd.read_sql(
    text("SELECT * FROM conversations WHERE created_at > :start_date"),
    engine,
    params={'start_date': '2025-01-01'}
)

为什么推荐 SQLAlchemy?三个原因。第一是连接池管理——每次查询不需要新建连接,引擎会自动复用已有连接。第二是参数化查询——用 :param_name 占位符而不是字符串拼接,彻底杜绝 SQL 注入风险。第三是跨数据库兼容性——同样的代码只需改连接字符串就能切换 PostgreSQL / MySQL / Oracle。

复杂查询:把重活交给数据库

一个重要的性能原则是:能用 SQL 完成的聚合和过滤,就不要拉到 Pandas 里再做。数据库有索引优化、查询计划器、列式存储等优势,对于大规模数据的初筛和预聚合比 Pandas 高效得多。比如你需要按日统计对话数量和质量分布:

python
query = """
WITH daily_stats AS (
    SELECT 
        DATE(created_at) AS date,
        COUNT(*) AS conversation_count,
        AVG(quality_score) AS avg_quality,
        SUM(CASE WHEN quality_score >= 4.0 THEN 1 ELSE 0 END) AS high_quality_count
    FROM conversations
    GROUP BY DATE(created_at)
),
top_users AS (
    SELECT 
        user_id,
        COUNT(*) AS total_conversations,
        AVG(quality_score) AS user_avg_score,
        RANK() OVER (ORDER BY AVG(quality_score) DESC) AS quality_rank
    FROM conversations
    GROUP BY user_id
)
SELECT 
    d.*,
    t.user_id AS top_user_that_day,
    t.user_avg_score
FROM daily_stats d
CROSS JOIN LATERAL (
    SELECT * FROM top_users WHERE quality_rank = 1 LIMIT 1
) t
ORDER BY d.date DESC
LIMIT 30
"""

df = pd.read_sql(query, conn)
print(df.head())

这种包含 CTE(Common Table Expression)、窗口函数(RANK)、LATERAL JOIN 的复杂查询,让数据库一次算完返回精简的结果集,而不是把几百万行原始数据全拉到 Pandas 里再 groupby。好的做法是:SQL 负责粗筛和预聚合,Pandas 负责精细的数据变换和特征工程

to_sql():把数据写回数据库

基本写入与 if_exists 策略

python
import pandas as pd

df = pd.DataFrame({
    'chunk_id': ['ch001', 'ch002', 'ch003'],
    'doc_source': ['api.md', 'guide.md', 'paper.pdf'],
    'token_count': [256, 512, 384],
    'embedding_model': ['text-embedding-3-small'] * 3,
})

df.to_sql('kb_chunks', conn, if_exists='replace', index=False)

if_exists 参数控制目标表已存在时的行为,这是最容易出错的参数之一,三种选项的含义必须搞清楚:

  • 'fail'(默认):表已存在则抛出异常——防止意外覆盖
  • 'replace':删除旧表、创建新表、写入数据——慎用,会丢失表中已有的其他数据
  • 'append':在现有表后面追加行——最常用,但要确保列结构一致
python
df_new = pd.DataFrame({
    'chunk_id': ['ch004'], 'doc_source': ['blog.md'],
    'token_count': [200], 'embedding_model': ['bge-large-zh'],
})

df_new.to_sql('kb_chunks', conn, if_exists='append', index=False)

批量写入的性能陷阱

to_sql() 默认的实现方式是逐行生成 INSERT 语句然后逐条执行。对于几千行数据这没问题,但如果你要写入几十万甚至几百万行,默认方式慢到不可接受:

python
import pandas as pd
import numpy as np
import time

n = 500_000
large_df = pd.DataFrame({
    'id': range(n),
    'value': np.random.randn(n),
    'category': np.random.choice(['A', 'B', 'C'], n),
})

start = time.time()
large_df.to_sql('test_default', conn, if_exists='replace', index=False, method=None)
t_default = time.time() - start

start = time.time()
large_df.to_sql('test_multi', conn, if_exists='replace', index=False, method='multi')
t_multi = time.time() - start

print(f"逐行 INSERT: {t_default:.1f}s")
print(f"批量 multi:  {t_multi:.1f}s")
print(f"加速: {t_default/t_multi:.0f}x")

典型输出:

逐行 INSERT: 180.3s
批量 multi:   12.8s
加速: 14x

14 倍的差异只来自一个参数:method='multi'。它的作用是把多条 VALUES 合并成一条 INSERT 语句(默认每批 1000 行),大幅减少网络往返次数。如果数据量更大(百万级行),还可以配合 chunksize 参数进一步控制每批的大小:

python
df.to_sql('big_table', engine, if_exists='append', index=False,
          method='multi', chunksize=10000)

对于 PostgreSQL 用户还有一个终极武器:COPY 协议。它绕过 SQL 解析器直接从文件批量导入,速度比 INSERT 快一个数量级。做法是先把 DataFrame 写成 CSV,再通过 COPY 导入:

python
import tempfile, os

csv_path = os.path.join(tempfile.gettempdir(), 'bulk_import.csv')
large_df.to_csv(csv_path, index=False, header=False)

with engine.begin() as conn:
    conn.execute(text(f"COPY big_table FROM '{csv_path}' WITH (FORMAT CSV)"))

生产级操作模板:DataWarehouse 类

把上面的知识整合起来,这里给出一个在生产环境中可以直接使用的数据库交互类。它封装了连接管理、参数化查询、批量写入、错误处理等常见需求:

python
import pandas as pd
from sqlalchemy import create_engine, text
from contextlib import contextmanager


class DataWarehouse:
    """LLM 数据仓库交互封装"""
    
    def __init__(self, connection_string: str):
        self.engine = create_engine(
            connection_string,
            pool_size=5,
            max_overflow=10,
            pool_pre_ping=True,
        )
    
    @contextmanager
    def _connection(self):
        with self.engine.connect() as conn:
            yield conn
    
    def extract_conversations(self, start_date: str, end_date: str,
                               min_quality: float = 3.0,
                               limit: int = None) -> pd.DataFrame:
        sql = text("""
            SELECT c.id, u.username, c.prompt, c.response,
                   c.quality_score, c.created_at, c.token_count
            FROM conversations c
            JOIN users u ON c.user_id = u.id
            WHERE c.created_at BETWEEN :start AND :end
              AND c.quality_score >= :min_q
            ORDER BY c.created_at DESC
        """)
        params = {'start': start_date, 'end': end_date, 'min_q': min_quality}
        
        with self._connection() as conn:
            result = conn.execute(sql, params)
            df = pd.DataFrame(result.all(), columns=result.keys())
        
        if limit and len(df) > limit:
            df = df.head(limit)
        return df
    
    def write_cleaned_data(self, df: pd.DataFrame, table_name: str):
        df.to_sql(
            table_name,
            self.engine,
            if_exists='replace',
            index=False,
            chunksize=50000,
            method='multi',
        )
    
    def append_evaluation_results(self, results_df: pd.DataFrame):
        results_df.to_sql(
            'evaluation_results',
            self.engine,
            if_exists='append',
            index=False,
            chunksize=10000,
            method='multi',
        )


dw = DataWarehouse("postgresql://user:pass@localhost:5432/llm_db")

conversations = dw.extract_conversations(
    start_date='2025-01-01',
    end_date='2025-03-31',
    min_quality=4.0,
)
print(f"拉取到 {len(conversations):,} 条高质量对话")

dw.write_cleaned_data(conversations, 'sft_train_v3')

这个类的设计有几个值得注意的点。第一,用 create_engine 而不是裸连接——自带连接池和自动重连。第二,SQL 查询用 text() 包装 + 字典传参——防止 SQL 注入。第三,to_sql() 统一指定了 method='multi'chunksize——避免逐行插入的性能陷阱。第四,用上下文管理器 (@contextmanager) 管理连接生命周期——确保连接一定被归还到连接池。

SQLite 作为本地缓存层

不是所有的数据库操作都需要连远程服务器。SQLite 是一个嵌入式的、无需安装的、零配置的数据库引擎,非常适合做本地数据缓存。在 LLM 数据处理流程中,我经常用它来缓存那些计算代价高但不常变动的中间结果:

python
import pandas as pd
import sqlite3
import pickle
from pathlib import Path
from datetime import datetime, timedelta


class LocalCache:
    """基于 SQLite 的本地计算结果缓存"""
    
    def __init__(self, db_path: str = '.cache/local.db'):
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(db_path)
        self.conn.execute("PRAGMA journal_mode=WAL")
        self.conn.execute("""CREATE TABLE IF EXISTS cache (
            key TEXT PRIMARY KEY,
            data BLOB,
            cached_at TIMESTAMP
        )""")
    
    def get_or_compute(self, key: str, compute_fn, ttl_hours: int = 24):
        cached = pd.read_sql(
            "SELECT data, cached_at FROM cache WHERE key=?",
            self.conn, params=[key]
        )
        
        if len(cached) > 0:
            cached_at = pd.to_datetime(cached.iloc[0]['cached_at'])
            if datetime.now() - cached_at < timedelta(hours=ttl_hours):
                return pickle.loads(cached.iloc[0]['data'])
        
        result = compute_fn()
        
        self.conn.execute(
            "INSERT OR REPLACE INTO cache (key, data, cached_at) VALUES (?, ?, ?)",
            (key, pickle.dumps(result), pd.Timestamp.now())
        )
        self.conn.commit()
        return result
    
    def close(self):
        self.conn.close()


cache = LocalCache()

df = cache.get_or_compute(
    'expensive_aggregation_v2',
    lambda: pd.DataFrame({'a': range(100000)}).groupby(lambda x: x % 10).sum(),
    ttl_hours=6,
)
print(df.head())
cache.close()

这个缓存类的逻辑很直观:先查 SQLite 看有没有缓存结果,如果有且未过期就直接返回;如果没有或已过期就重新计算、序列化后存入 SQLite。下次同样的查询就会命中缓存。对于那种需要跑几分钟的复杂聚合操作,缓存能节省大量重复计算时间。

常见问题速查

最后快速列出几个数据库 I/O 中最高频的问题和解决方案:

时区不一致。数据库存的通常是 UTC 时间,但你可能需要北京时间来分析:

python
df['created_at'] = pd.to_datetime(df['created_at'])
df['beijing_time'] = df['created_at'].dt.tz_localize('UTC').dt.tz_convert('Asia/Shanghai')

大数据量导出超时。不要试图一次性 to_sql 几百万行,分批追加:

python
for i in range(0, len(large_df), 50000):
    batch = large_df.iloc[i:i+50000]
    batch.to_sql('target', engine, if_exists='append',
                 index=False, method='multi', chunksize=10000)

内存型 DataFrame 直接入库用于测试。开发阶段不需要真的连数据库,SQLite 的 :memory: 模式让你可以在内存中创建临时数据库跑通整个流程:

python
conn = sqlite3.connect(':memory:')
df.to_sql('test_table', conn)
result = pd.read_sql("SELECT * FROM test_table WHERE score > 3", conn)

到这里,文本文件(CSV/JSONL)和数据库两种数据来源的读写方法都覆盖了。下一节我们要看的是 Parquet、Feather 这些更适合大规模数据的高性能格式。

基于 MIT 许可发布