概述

数据处理和分析是从原始数据中提取有价值信息的过程它是数据科学机器学习和商业智能的基础现代数据分析强调自动化可重复性和可视化通过系统化的方法将杂乱无章的数据转化为清晰的洞察和决策依据

⭐⭐ 概念

数据处理对原始数据进行收集清洗转换和整合的过程确保数据质量和一致性


数据分析运用统计方法和分析工具从处理后的数据中发现模式趋势和关联关系
数据可视化通过图表图形和仪表板将分析结果以直观的方式呈现便于理解和决策
探索性数据分析EDA在正式建模前对数据进行初步探索了解数据分布异常值和相关性

⭐⭐ 特点

数据驱动决策基于客观数据而非主观判断进行业务决策提高决策的准确性和科学性


多源数据整合能够处理来自数据库API文件传感器日志等多种数据源的异构数据
实时性要求支持批处理和流式处理两种模式满足不同场景下的实时分析需求
可扩展性从小规模数据集到 TB/PB 级大数据集群都能高效处理具备良好的水平扩展能力
可视化呈现通过丰富的图表类型和交互式仪表板直观展示分析结果降低理解门槛
自动化流程支持数据管道Pipeline自动化实现从数据采集到报告生成的全流程自动化
迭代优化分析过程可重复可验证便于持续改进模型和方法形成良性循环
跨领域应用适用于金融医疗电商制造教育等各个行业具有广泛的适用性
技术栈丰富拥有 NumPyPandasMatplotlibSeaborn 等成熟的开源工具和库生态系统

⭐⭐ 工作流程

数据收集从各种来源获取原始数据包括数据库查询API 调用文件导入网络爬虫等


数据清洗处理缺失值异常值重复值统一数据格式确保数据质量
数据转换进行特征工程标准化编码转换等操作使数据适合分析需求
探索分析运用描述性统计可视化等方法初步了解数据特征和分布规律
深入分析应用统计分析机器学习时间序列等方法挖掘数据价值
结果可视化通过图表报告仪表板等形式展示分析结果和洞察
决策支持基于分析结果为业务决策提供数据支持和建议

Anaconda

Anaconda 是一个开源的 Python 和 R 语言数据科学平台包含包管理环境管理和众多预装的数据科学工具它是数据科学家和机器学习工程师的首选开发环境

  • 核心功能
    • 包管理Conda 包管理器支持 PythonR 等多语言
    • 环境管理创建隔离的虚拟环境避免依赖冲突
    • 预装工具Jupyter NotebookSpyder 等 IDE
    • 数据科学库NumPyPandasMatplotlibScikit-learn 等
  • 主要优势
    • 开箱即用预装 1500+ 数据科学包无需单独安装
    • 跨平台支持 WindowsmacOSLinux
    • 环境隔离轻松管理多个项目的环境和依赖
    • 社区活跃强大的社区支持和丰富的文档资源

核心组件

  • Conda
    • 包管理和环境管理工具
    • 支持多语言PythonRRuby 等
    • 自动解决依赖关系
  • Anaconda Navigator
    • 图形化界面管理工具
    • 可视化管理环境包和应用
  • Jupyter Notebook
    • 交互式编程环境
    • 支持代码文本图表混合展示
  • Spyder
    • 专为数据科学设计的 IDE
    • 集成编辑器控制台和变量浏览器

版本选择

  • Anaconda
    • 完整版包含 1500+ 数据科学包
    • 安装包较大约 3GB
    • 适合初学者和需要完整功能的用户
  • Miniconda
    • 精简版只包含 Conda 和 Python
    • 安装包较小约 400MB
    • 适合高级用户和需要自定义环境的场景

工作原理

环境管理机制

1
2
3
4
5
6
7
8
9
10
环境隔离原理
1. 每个环境有独立的 Python 解释器
2. 每个环境有独立的包安装目录
3. 环境之间互不影响
4. 通过激活/切换环境来使用不同配置

优势
- 避免项目间依赖冲突
- 便于复现开发环境
- 支持多版本 Python 共存

包管理流程

1
2
3
4
5
6
7
8
9
包安装流程
搜索包 → 解析依赖 → 下载包 → 安装包 → 更新元数据

关键步骤
- 搜索在 Conda 仓库中查找包
- 解析分析包的依赖关系
- 下载从镜像源下载包文件
- 安装解压并配置包
- 更新记录已安装的包信息

安装配置

安装版本建议

1
2
# 如果不需要完整 Anaconda使用轻量级 Miniconda
# 下载 Minicondahttps://docs.conda.io/en/latest/miniconda.html

Windows 安装

1
2
3
4
5
6
7
8
9
# 方法一图形化安装
1. 下载安装包https://www.anaconda.com/products/distribution
2. 双击 .exe 文件运行安装程序
3. 选择安装路径建议默认
4. 勾选"Add Anaconda to PATH"可选
5. 完成安装

# 方法二命令行安装使用 Chocolatey
choco install anaconda3

macOS 安装

1
2
3
4
5
6
7
8
9
10
# 方法一图形化安装
1. 下载 .pkg 安装包
2. 双击运行安装程序
3. 按照提示完成安装

# 方法二使用 Homebrew
brew install --cask anaconda

# 方法三命令行安装
bash Anaconda3-2023.09-0-MacOSX-x86_64.sh

Linux 安装

1
2
3
4
5
6
7
8
9
# 下载安装脚本
wget https://repo.anaconda.com/archive/Anaconda3-2023.09-0-Linux-x86_64.sh

# 运行安装脚本
bash Anaconda3-2023.09-0-Linux-x86_64.sh

# 按照提示完成安装
# 安装完成后重新加载配置
source ~/.bashrc

验证安装

1
2
3
4
5
6
7
8
9
10
11
# 查看 Conda 版本
conda --version

# 查看 Anaconda 信息
conda info

# 查看已安装的包
conda list

# 查看 Python 版本
python --version

配置镜像源

1
2
3
4
5
6
7
8
9
10
11
12
13
# 添加清华镜像源推荐国内用户
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/

# 设置显示通道地址
conda config --set show_channel_urls yes

# 查看当前配置
conda config --show channels

# 恢复默认源
conda config --remove-key channels

常用配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 设置自动激活 base 环境默认开启
conda config --set auto_activate_base true

# 关闭自动激活
conda config --set auto_activate_base false

# 设置包缓存目录
conda config --set pkgs_dirs /path/to/pkgs

# 设置环境目录
conda config --set envs_dirs /path/to/envs

# 查看所有配置
conda config --show

环境管理

创建环境

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 创建新环境指定 Python 版本
conda create --name myenv python=3.9

# 创建环境并安装包
conda create --name myenv python=3.9 numpy pandas matplotlib

# 创建环境时指定包版本
conda create --name myenv python=3.9 numpy=1.24.3 pandas=2.0.0

# 从配置文件创建环境
conda env create -f environment.yml

# 克隆现有环境
conda create --name newenv --clone oldenv

切换环境

1
2
3
4
5
6
7
8
9
10
11
12
13
# 激活环境
conda activate myenv

# 退出当前环境
conda deactivate

# 切换到 base 环境
conda activate base

# 查看所有环境
conda env list
# 或
conda info --envs

删除环境

1
2
3
4
5
6
# 删除指定环境
conda remove --name myenv --all

# 删除当前激活的环境
conda deactivate
conda remove --name $CONDA_DEFAULT_ENV --all

导出导入

1
2
3
4
5
6
7
8
9
10
11
# 导出环境配置
conda env export > environment.yml

# 导出精简配置仅包含手动安装的包
conda env export --from-history > environment.yml

# 从配置文件创建环境
conda env create -f environment.yml

# 更新现有环境
conda env update -f environment.yml --prune

使用建议

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 1. 为每个项目创建独立环境
# 避免在 base 环境中安装包

# 2. 定期清理未使用的环境
conda env list
conda remove --name unused_env --all

# 3. 导出环境配置用于版本控制
conda env export --from-history > environment.yml

# 4. 使用 .condarc 文件管理配置
# 在项目根目录创建 .condarc
channels:
- defaults
- conda-forge

💗💗 environment.yml 示例

1
2
3
4
5
6
7
8
9
10
11
12
13
name: myenv
channels:
- defaults
- conda-forge
dependencies:
- python=3.9
- numpy=1.24.3
- pandas=2.0.0
- matplotlib=3.7.1
- scikit-learn=1.2.2
- pip:
- tensorflow==2.13.0
- keras==2.13.1

包管理

安装包

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 从 Conda 安装包
conda install numpy
conda install numpy pandas matplotlib

# 指定版本安装
conda install numpy=1.24.3

# 从 conda-forge 频道安装
conda install -c conda-forge package_name

# 使用 pip 安装包在 Conda 环境中
pip install package_name

# 批量安装从 requirements.txt
pip install -r requirements.txt

更新包

1
2
3
4
5
6
7
8
9
10
11
# 更新指定包
conda update numpy

# 更新所有包
conda update --all

# 更新 Conda 本身
conda update conda

# 更新 Anaconda 发行版
conda update anaconda

卸载包

1
2
3
4
5
6
7
8
# 卸载指定包
conda remove numpy

# 卸载包及其依赖
conda remove numpy --force

# 清理未使用的包和缓存
conda clean --all

搜索包

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 搜索包
conda search numpy

# 搜索特定版本的包
conda search numpy=1.24

# 查看包的详细信息
conda search numpy --info

# 查看已安装的包
conda list

# 查看特定环境中的包
conda list -n myenv

# 导出已安装的包列表
conda list --export > package-list.txt

使用建议

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 1. 优先使用 Conda 安装包
conda install package_name

# 2. 必要时使用 pip
pip install package_name

# 3. 避免混用 Conda 和 pip 安装同一包
# 如果必须混用先安装所有 Conda 包再安装 pip 包

# 4. 定期更新包
conda update --all

# 5. 记录包版本以便复现
conda list --export > packages.txt

通道管理

管理通道

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 查看当前通道
conda config --show channels

# 添加通道
conda config --add channels conda-forge
conda config --add channels bioconda

# 设置通道优先级
conda config --set channel_priority strict

# 移除通道
conda config --remove channels conda-forge

# 恢复默认通道
conda config --remove-key channels

常用通道

通道名称 说明 适用场景
defaults Anaconda 官方默认通道 通用包
conda-forge 社区维护的通道 最新版本的包
bioconda 生物信息学包 生物数据分析
pytorch PyTorch 官方通道 深度学习

虚拟环境

命名规范

1
2
3
4
5
6
7
8
9
10
11
# 环境命名
myproject_env # 项目环境
data_science_env # 数据科学环境
ml_experiment_env # 机器学习实验环境
web_dev_env # Web 开发环境

# 避免使用特殊字符
# 好
conda create --name my_env
# 不好
conda create --name my-env # 可能在某些系统中出现问题

项目环境隔离

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 为每个项目创建独立环境
mkdir my_project
cd my_project

# 创建环境
conda create --name my_project_env python=3.9

# 激活环境
conda activate my_project_env

# 安装项目依赖
conda install numpy pandas scikit-learn
pip install tensorflow

# 导出环境配置
conda env export > environment.yml

# 将 environment.yml 加入版本控制
git add environment.yml

环境共享

1
2
3
4
5
6
7
8
# 团队成员 A导出环境
conda env export > environment.yml

# 团队成员 B创建相同环境
conda env create -f environment.yml

# 在不同机器上复现环境
conda env create -f environment.yml --name project_env

性能优化

加速包安装

1
2
3
4
5
6
7
8
9
10
11
# 使用 mamba更快的包管理器
conda install mamba -n base -c conda-forge

# 配置国内镜像源加速下载
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/

# 使用 mamba 安装包
mamba install numpy pandas

# 创建环境
mamba create --name myenv python=3.9 numpy pandas

清理缓存

1
2
3
4
5
6
7
8
# 清理未使用的包和缓存
conda clean --packages --tarballs

# 清理所有缓存
conda clean --all

# 查看缓存大小
conda clean --all --dry-run

并行下载

1
2
3
# 设置并行下载数
conda config --set remote_read_timeout_secs 600
conda config --set ssl_verify no # 仅在可信网络中使用

集成其他 IDE

VS Code 集成

1
2
3
4
5
6
7
配置步骤
1. 安装 Python 扩展
2. 选择 Conda 环境作为解释器
- Ctrl+Shift+P → "Python: Select Interpreter"
- 选择 Conda 环境
3. 在终端中激活环境
- VS Code 会自动激活选定的 Conda 环境

PyCharm 集成

1
2
3
4
5
6
配置步骤
1. File → Settings → Project → Python Interpreter
2. 点击齿轮图标 → Add
3. 选择 "Conda Environment"
4. 选择 "Existing environment" 或 "New environment"
5. 指定 Conda 可执行文件路径

Docker 集成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 使用 Anaconda 基础镜像
FROM continuumio/anaconda3:latest

# 复制环境配置
COPY environment.yml /tmp/environment.yml

# 创建环境
RUN conda env create -f /tmp/environment.yml

# 激活环境
SHELL ["conda", "run", "-n", "myenv", "/bin/bash", "-c"]

# 运行应用
CMD ["python", "app.py"]

Spyder IDE

1
2
3
4
5
# 启动 Spyder
spyder

# 重置 Spyder 配置
spyder --reset

Jupyter

Jupyter Notebook 是一个开源的交互式计算环境支持创建和共享包含实时代码方程可视化和文本的文档它是数据科学机器教育和科研领域的标准工具

  • 核心功能
    • 交互式编程实时执行代码并查看结果
    • 富文本支持MarkdownLaTeX 公式HTML
    • 可视化集成图表图像视频嵌入
    • 多语言支持PythonRJuliaScala 等 40+ 语言
  • 主要优势
    • 易于使用基于浏览器的界面无需复杂配置
    • 可复现性代码数据和结果保存在同一文档中
    • 协作友好支持分享和版本控制
    • 生态丰富与数据科学库无缝集成

核心组件

  • Jupyter Notebook
    • 经典的笔记本界面
    • .ipynb 格式保存文档
    • 适合教学和快速原型开发
  • JupyterLab
    • 下一代 IDE 界面
    • 模块化设计支持扩展
    • 适合作为日常开发环境
  • JupyterHub
    • 多用户服务器
    • 支持团队协作和教育场景
    • 集中管理多个 Notebook 实例
  • nbconvert
    • 格式转换工具
    • 支持转换为 HTMLPDFMarkdown 等格式

文件格式

  • .ipynb (IPython Notebook)
    • JSON 格式的笔记本文件
    • 包含代码单元格输出元数据
    • 可用文本编辑器查看和编辑
  • .py (Python 脚本)
    • 可从 Notebook 导出
    • 适合部署和生产环境

工作原理

架构设计

1
2
3
4
5
6
7
8
Jupyter 架构
浏览器前端 ←→ Jupyter Server ←→ Kernel

关键组件
1. 前端Web 界面用户交互
2. 服务器处理请求管理会话
3. Kernel执行代码的内核如 IPython
4. 通信通过 WebSocket 和 REST API

执行流程

1
2
3
4
5
6
7
8
用户操作 → 前端发送请求 → 服务器转发 → Kernel 执行 → 返回结果 → 前端显示

关键步骤
- 输入用户在单元格中编写代码
- 发送前端将代码发送到服务器
- 执行Kernel 执行代码
- 返回执行结果返回给前端
- 显示前端渲染结果显示

安装配置

使用 Conda 安装

1
2
3
4
5
6
7
8
9
# 安装 Jupyter Notebook
conda install jupyter

# 安装 JupyterLab
conda install jupyterlab

# 在现有环境中安装
conda activate myenv
conda install jupyter

使用 pip 安装

1
2
3
4
5
6
7
8
# 安装 Jupyter Notebook
pip install jupyter

# 安装 JupyterLab
pip install jupyterlab

# 安装所有 Jupyter 组件
pip install jupyterlab notebook nbconvert

验证安装

1
2
3
4
5
6
7
8
9
10
11
# 查看 Jupyter 版本
jupyter --version

# 查看已安装的 kernels
jupyter kernelspec list

# 启动 Jupyter Notebook
jupyter notebook

# 启动 JupyterLab
jupyter lab

基本启动

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 启动 Jupyter Notebook
jupyter notebook

# 启动 JupyterLab
jupyter lab

# 指定端口
jupyter notebook --port 8888
jupyter lab --port 9999

# 不打开浏览器
jupyter notebook --no-browser
jupyter lab --no-browser

# 指定工作目录
jupyter notebook --notebook-dir=/path/to/dir
jupyter lab --notebook-dir=/path/to/dir

配置文件

1
2
3
4
5
6
7
8
9
# 生成 Notebook 配置
jupyter notebook --generate-config

# 生成 JupyterLab 配置
jupyter lab --generate-config

# 配置文件位置
# Windows: C:\Users\username\.jupyter\jupyter_notebook_config.py
# Linux/macOS: ~/.jupyter/jupyter_notebook_config.py

💗💗 常用配置项

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# jupyter_notebook_config.py

# 设置默认浏览器
c.NotebookApp.browser = 'chrome'

# 设置工作目录
c.NotebookApp.notebook_dir = '/path/to/workspace'

# 设置端口
c.NotebookApp.port = 8888

# 允许远程访问
c.NotebookApp.ip = '0.0.0.0'

# 禁用令牌认证仅用于本地测试
c.NotebookApp.token = ''

# 自动打开浏览器
c.NotebookApp.open_browser = True

# 设置密码保护
from notebook.auth import passwd
c.NotebookApp.password = passwd('your_password')

密码保护

1
2
3
4
5
# 生成密码哈希
python -c "from notebook.auth import passwd; print(passwd())"

# 输入密码后将生成的哈希值添加到配置文件
# c.NotebookApp.password = 'sha1:...'

Kernels 管理

基本操作

1
2
3
4
5
6
7
8
9
10
11
# 查看已安装的 kernels
jupyter kernelspec list

# 安装新的 kernel
python -m ipykernel install --user --name myenv --display-name "Python (myenv)"

# 卸载 kernel
jupyter kernelspec uninstall myenv

# 查看 kernel 信息
jupyter kernelspec describe python3

多语言支持

1
2
3
4
5
6
7
8
9
10
11
12
# 安装 R kernel
conda install -c r r-irkernel

# 安装 Julia kernel
using Pkg
Pkg.add("IJulia")

# 安装 Scala kernel
# 参考 Almond 项目https://almond.sh/

# 查看可用的 kernels
jupyter kernelspec list

快捷键

命令模式

按 Esc 进入命令模式

快捷键 功能
A 在上方插入单元格
B 在下方插入单元格
D, D 删除单元格按两次 D
M 转换为 Markdown 单元格
Y 转换为代码单元格
R 转换为 Raw 单元格
Shift + ↑/↓ 选择多个单元格
Ctrl + Shift + - 分割单元格
Z 撤销删除

编辑模式

按 Enter 进入编辑模式

快捷键 功能
Shift + Enter 运行单元格并选中下一个
Ctrl + Enter 运行单元格
Alt + Enter 运行单元格并在下方插入新单元格
Tab 代码补全或缩进
Shift + Tab 显示函数文档
Ctrl + ] / [ 增加/减少缩进
Ctrl + / 注释/取消注释

基本操作

创建和保存

1
2
3
4
5
6
7
8
9
10
11
12
# 新建 Notebook
# File → New Notebook → 选择 Kernel

# 重命名
# File → Rename

# 保存
# File → Save and Checkpoint
# 或 Ctrl + S

# 下载
# File → Download as → 选择格式

运行代码

1
2
3
4
5
6
7
8
9
10
11
# 单个单元格
print("Hello, Jupyter!")

# 运行所有单元格
# Kernel → Run All

# 运行到指定单元格
# Cell → Run Above / Run Below

# 重启 Kernel 并运行所有
# Kernel → Restart & Run All

扩展管理

安装扩展

1
2
3
4
5
6
7
8
9
10
11
12
# 查看已安装的扩展
jupyter labextension list

# 安装扩展
jupyter labextension install @jupyterlab/git
jupyter labextension install @jupyterlab/toc

# 卸载扩展
jupyter labextension uninstall @jupyterlab/git

# 更新扩展
jupyter labextension update

常用扩展

扩展名称 功能
@jupyterlab/git Git 版本控制集成
@jupyterlab/toc 自动生成目录
@jupyter-widgets/jupyterlab-manager 交互式 widgets 支持
@ryantam626/jupyterlab_code_formatter 代码格式化
@krassowski/jupyterlab-lsp 语言服务器协议支持

高级功能

魔术命令

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# 查看可用魔术命令
%lsmagic

# 计时
%time sum(range(1000))
%timeit sum(range(1000))

# 查看历史
%history

# 运行外部脚本
%run script.py

# 加载外部文件
%load script.py

# 保存单元格内容
%%writefile output.py
print("Hello")

# 查看工作目录
%pwd
%ls

# 切换工作目录
%cd /path/to/dir

# 环境变量
%env
%env MY_VAR=value

# matplotlib 集成
%matplotlib inline
%matplotlib notebook

系统命令

1
2
3
4
5
6
7
8
9
10
11
12
# 在单元格中执行系统命令
!ls
!pwd
!pip install package

# 捕获输出
files = !ls *.py
print(files)

# 传递 Python 变量
filename = "test.py"
!cat {filename}

Widgets 交互

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import ipywidgets as widgets
from IPython.display import display

# 滑块
slider = widgets.IntSlider(
value=50,
min=0,
max=100,
description='Value:'
)
display(slider)

# 文本框
text = widgets.Text(
value='Hello',
description='Input:'
)
display(text)

# 按钮
button = widgets.Button(description='Click Me')
display(button)

# 绑定事件
def on_button_clicked(b):
print("Button clicked!")

button.on_click(on_button_clicked)

# 组合 widgets
ui = widgets.HBox([slider, text])
display(ui)

NumPy

NumPy(Numerical Python)是 Python 科学计算的基础库,提供了高性能的多维数组对象和丰富的数学函数它是数据分析机器学习和科学计算的核心工具

  • 核心功能
    • 多维数组对象:ndarray,支持高效的数值计算
    • 广播机制:自动处理不同形状数组的运算
    • 数学函数:线性代数傅里叶变换随机数生成等
    • 集成能力:与 C/C++/Fortran 代码无缝集成
  • 主要优势
    • 性能卓越:底层使用 C 语言实现,比原生 Python 快数十倍
    • 内存高效:连续存储数据,减少内存开销
    • 功能丰富:提供全面的数学运算和统计函数
    • 生态完善:是 PandasSciPyScikit-learn 等库的基础

核心组件

  • ndarray
    • N 维数组对象,NumPy 的核心数据结构
    • 所有元素类型相同,存储在连续的内存块中
  • ufunc(通用函数)
    • 对数组进行逐元素操作的函数
    • 支持广播机制,自动处理不同形状的数组
  • 索引和切片
    • 支持整数索引布尔索引花式索引
    • 提供强大的数据选择和过滤能力
  • 广播机制
    • 自动扩展不同形状的数组以进行运算
    • 避免显式循环,提高代码效率

数据类型

  • 基本类型
    • int8, int16, int32, int64:整数类型
    • float16, float32, float64:浮点数类型
    • bool:布尔类型
    • complex64, complex128:复数类型
  • 特殊类型
    • datetime64:日期时间类型
    • timedelta64:时间差类型
    • object:Python 对象类型

工作原理

内存布局

1
2
3
4
5
6
7
8
9
10
ndarray 内部结构:
1. 数据指针:指向连续的内存块
2. 形状(shape):数组的维度信息
3. 步长(strides):每个维度移动一个元素需要的字节数
4. 数据类型(dtype):数组元素的数据类型

优势:
- 连续内存访问,缓存友好
- 向量化运算,利用 CPU SIMD 指令
- 避免 Python 对象的 overhead

运算流程

1
2
3
4
5
6
7
创建数组 → 执行运算 → 结果存储

关键步骤:
- 创建:分配连续内存,填充数据
- 运算:调用底层 C/Fortran 函数
- 广播:自动调整数组形状
- 返回:创建新的 ndarray 或修改原数组

环境搭建

pip 安装

1
2
3
4
5
6
7
8
# 基础安装
pip install numpy

# 指定版本安装
pip install numpy==1.24.3

# 升级安装
pip install --upgrade numpy

conda 安装

1
2
3
4
5
# 使用 conda 安装
conda install numpy

# 指定环境安装
conda install -n myenv numpy

验证安装

1
2
3
import numpy as np
print(np.__version__) # 输出版本号
print(np.show_config()) # 显示配置信息

创建数组

从列表创建

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import numpy as np

# 一维数组
arr1 = np.array([1, 2, 3, 4, 5])
print(arr1) # [1 2 3 4 5]
print(type(arr1)) # <class 'numpy.ndarray'>

# 二维数组
arr2 = np.array([[1, 2, 3], [4, 5, 6]])
print(arr2)
# [[1 2 3]
# [4 5 6]]

# 指定数据类型
arr3 = np.array([1, 2, 3], dtype=np.float64)
print(arr3.dtype) # float64

常用创建函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 创建全零数组
zeros = np.zeros((3, 4))
print(zeros)
# [[0. 0. 0. 0.]
# [0. 0. 0. 0.]
# [0. 0. 0. 0.]]

# 创建全一数组
ones = np.ones((2, 3))
print(ones)
# [[1. 1. 1.]
# [1. 1. 1.]]

# 创建单位矩阵
eye = np.eye(3)
print(eye)
# [[1. 0. 0.]
# [0. 1. 0.]
# [0. 0. 1.]]

# 创建等差数列
arange = np.arange(0, 10, 2)
print(arange) # [0 2 4 6 8]

# 创建等间隔数组
linspace = np.linspace(0, 1, 5)
print(linspace) # [0. 0.25 0.5 0.75 1. ]

# 创建随机数组
random = np.random.rand(3, 3)
print(random)

# 创建正态分布随机数
normal = np.random.randn(3, 3)
print(normal)

# 创建指定范围的随机整数
randint = np.random.randint(0, 10, (3, 3))
print(randint)

特殊数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 对角矩阵
diag = np.diag([1, 2, 3])
print(diag)
# [[1 0 0]
# [0 2 0]
# [0 0 3]]

# 全填充数组
full = np.full((2, 3), 7)
print(full)
# [[7 7 7]
# [7 7 7]]

# 空数组(未初始化)
empty = np.empty((2, 2))
print(empty) # 值是随机的

# 从现有数组创建副本
arr = np.array([1, 2, 3])
copy_arr = arr.copy()

数组属性

基本属性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
arr = np.array([[1, 2, 3], [4, 5, 6]])

# 维度数
print(arr.ndim) # 2

# 形状
print(arr.shape) # (2, 3)

# 元素总数
print(arr.size) # 6

# 数据类型
print(arr.dtype) # int32 或 int64

# 每个元素的字节数
print(arr.itemsize) # 4 或 8

# 总字节数
print(arr.nbytes) # 24 或 48

重塑数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
arr = np.arange(12)
print(arr) # [ 0 1 2 3 4 5 6 7 8 9 10 11]

# 重塑形状
reshaped = arr.reshape(3, 4)
print(reshaped)
# [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]

# 自动推断维度
auto_reshape = arr.reshape(2, -1)
print(auto_reshape)
# [[ 0 1 2 3 4 5]
# [ 6 7 8 9 10 11]]

# 展平数组
flattened = reshaped.flatten()
print(flattened) # [ 0 1 2 3 4 5 6 7 8 9 10 11]

# 转置
transposed = reshaped.T
print(transposed)
# [[ 0 4 8]
# [ 1 5 9]
# [ 2 6 10]
# [ 3 7 11]]

数组索引

基本索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
arr = np.array([10, 20, 30, 40, 50])

# 单个元素索引
print(arr[0]) # 10
print(arr[-1]) # 50

# 切片
print(arr[1:4]) # [20 30 40]
print(arr[::2]) # [10 30 50]
print(arr[::-1]) # [50 40 30 20 10]

# 二维数组索引
arr2d = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(arr2d[0, 1]) # 2
print(arr2d[1, :]) # [4 5 6]
print(arr2d[:, 2]) # [3 6 9]

布尔索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# 条件筛选
mask = arr > 5
print(mask) # [False False False False False True True True True True]
print(arr[mask]) # [ 6 7 8 9 10]

# 复合条件
print(arr[(arr > 3) & (arr < 8)]) # [4 5 6 7]
print(arr[(arr < 3) | (arr > 8)]) # [ 1 2 9 10]

# 修改符合条件的元素
arr[arr > 5] = 0
print(arr) # [1 2 3 4 5 0 0 0 0 0]

花式索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
arr = np.array([10, 20, 30, 40, 50])

# 使用索引数组
indices = [0, 2, 4]
print(arr[indices]) # [10 30 50]

# 二维数组花式索引
arr2d = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
rows = [0, 2]
cols = [1, 2]
print(arr2d[rows, cols]) # [2 9]

# 选择特定行
print(arr2d[[0, 2]])
# [[1 2 3]
# [7 8 9]]

数组运算

拼接数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

# 水平拼接
h_stack = np.hstack((a, b))
print(h_stack) # [1 2 3 4 5 6]

# 垂直拼接(需要二维数组)
a2d = np.array([[1, 2, 3]])
b2d = np.array([[4, 5, 6]])
v_stack = np.vstack((a2d, b2d))
print(v_stack)
# [[1 2 3]
# [4 5 6]]

# 通用拼接
concat = np.concatenate([a, b])
print(concat) # [1 2 3 4 5 6]

# 沿指定轴拼接
arr1 = np.array([[1, 2], [3, 4]])
arr2 = np.array([[5, 6], [7, 8]])
print(np.concatenate([arr1, arr2], axis=0)) # 按行拼接
# [[1 2]
# [3 4]
# [5 6]
# [7 8]]
print(np.concatenate([arr1, arr2], axis=1)) # 按列拼接
# [[1 2 5 6]
# [3 4 7 8]]

分割数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])

# 等分
split = np.split(arr, 3)
print(split) # [array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]

# 在指定位置分割
split2 = np.split(arr, [3, 6])
print(split2) # [array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]

# 水平分割
arr2d = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
hsplit = np.hsplit(arr2d, 2)
print(hsplit)
# [array([[1, 2],
# [5, 6]]),
# array([[3, 4],
# [7, 8]])]

# 垂直分割
vsplit = np.vsplit(arr2d, 2)
print(vsplit)
# [array([[1, 2, 3, 4]]),
# array([[5, 6, 7, 8]])]

算术运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
a = np.array([1, 2, 3, 4])
b = np.array([10, 20, 30, 40])

# 加法
print(a + b) # [11 22 33 44]
print(np.add(a, b)) # [11 22 33 44]

# 减法
print(b - a) # [ 9 18 27 36]
print(np.subtract(b, a)) # [ 9 18 27 36]

# 乘法
print(a * b) # [ 10 40 90 160]
print(np.multiply(a, b)) # [ 10 40 90 160]

# 除法
print(b / a) # [10. 10. 10. 10.]
print(np.divide(b, a)) # [10. 10. 10. 10.]

# 幂运算
print(a ** 2) # [ 1 4 9 16]
print(np.power(a, 2)) # [ 1 4 9 16]

# 取模
print(b % a) # [0 0 0 0]
print(np.mod(b, a)) # [0 0 0 0]

比较运算

1
2
3
4
5
6
7
8
9
10
11
a = np.array([1, 2, 3, 4, 5])
b = np.array([5, 4, 3, 2, 1])

# 比较运算
print(a > b) # [False False False True True]
print(a == b) # [False False True False False]
print(a != b) # [ True True False True True]

# 最大值和最小值
print(np.maximum(a, b)) # [5 4 3 4 5]
print(np.minimum(a, b)) # [1 2 3 2 1]

集合运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
a = np.array([1, 2, 3, 4, 5])
b = np.array([4, 5, 6, 7, 8])

# 交集
print(np.intersect1d(a, b)) # [4 5]

# 并集
print(np.union1d(a, b)) # [1 2 3 4 5 6 7 8]

# 差集
print(np.setdiff1d(a, b)) # [1 2 3]

# 对称差集
print(np.setxor1d(a, b)) # [1 2 3 6 7 8]

# 判断元素是否在集合中
print(np.in1d([1, 6, 9], a)) # [ True False False]

矩阵运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# 矩阵乘法
print(np.dot(A, B))
# [[19 22]
# [43 50]]
print(A @ B) # Python 3.5+ 支持的运算符
# [[19 22]
# [43 50]]

# 转置
print(A.T)
# [[1 3]
# [2 4]]

# 逆矩阵
A_inv = np.linalg.inv(A)
print(A_inv)
# [[-2. 1. ]
# [ 1.5 -0.5]]

# 验证:A * A_inv = I
print(np.dot(A, A_inv))
# [[1. 0.]
# [0. 1.]]

# 行列式
print(np.linalg.det(A)) # -2.0

# 秩
print(np.linalg.matrix_rank(A)) # 2

# 迹
print(np.trace(A)) # 5

数学函数

三角函数

1
2
3
4
5
6
7
8
9
10
11
12
13
angles = np.array([0, np.pi/2, np.pi, 3*np.pi/2])

# 正弦
print(np.sin(angles)) # [ 0. 1. 0. -1.]

# 余弦
print(np.cos(angles)) # [ 1. 0. -1. 0.]

# 正切
print(np.tan(angles)) # [ 0. inf -0. inf]

# 反正弦
print(np.arcsin([0, 1, -1])) # [ 0. 1.57079633 -1.57079633]

指数对数

1
2
3
4
5
6
7
8
9
10
11
12
13
arr = np.array([1, 2, 3, 4])

# 自然指数
print(np.exp(arr)) # [ 2.71828183 7.3890561 20.08553692 54.59815003]

# 自然对数
print(np.log(arr)) # [0. 0.69314718 1.09861229 1.38629436]

# 以 10 为底的对数
print(np.log10(arr)) # [0. 0.30103 0.47712125 0.60205999]

# 以 2 为底的对数
print(np.log2(arr)) # [0. 1. 1.5849625 2. ]

舍入函数

1
2
3
4
5
6
7
8
9
10
11
12
13
arr = np.array([1.234, 2.567, 3.999, -1.5])

# 向下取整
print(np.floor(arr)) # [ 1. 2. 3. -2.]

# 向上取整
print(np.ceil(arr)) # [ 2. 3. 4. -1.]

# 四舍五入
print(np.round(arr, 2)) # [ 1.23 2.57 4. -1.5 ]

# 截断小数
print(np.trunc(arr)) # [ 1. 2. 3. -1.]

统计函数

基本统计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# 求和
print(np.sum(arr)) # 55

# 平均值
print(np.mean(arr)) # 5.5

# 中位数
print(np.median(arr)) # 5.5

# 标准差
print(np.std(arr)) # 2.8722813232690143

# 方差
print(np.var(arr)) # 8.25

# 最小值和最大值
print(np.min(arr)) # 1
print(np.max(arr)) # 10

# 最小值和最大值的索引
print(np.argmin(arr)) # 0
print(np.argmax(arr)) # 9

轴向上的统计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
arr2d = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 按列求和(axis=0)
print(np.sum(arr2d, axis=0)) # [12 15 18]

# 按行求和(axis=1)
print(np.sum(arr2d, axis=1)) # [ 6 15 24]

# 按列求平均值
print(np.mean(arr2d, axis=0)) # [4. 5. 6.]

# 按行求最大值
print(np.max(arr2d, axis=1)) # [3 6 9]

# 累积求和
print(np.cumsum(arr2d)) # [ 1 3 6 10 15 21 28 36 45]

唯一值和计数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
arr = np.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4])

# 唯一值
unique = np.unique(arr)
print(unique) # [1 2 3 4]

# 唯一值及其出现次数
unique, counts = np.unique(arr, return_counts=True)
print(unique) # [1 2 3 4]
print(counts) # [1 2 3 4]

# 唯一值的索引
unique, indices = np.unique(arr, return_index=True)
print(indices) # [0 1 3 6]

排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
arr = np.array([3, 1, 4, 1, 5, 9, 2, 6])

# 排序(返回新数组)
sorted_arr = np.sort(arr)
print(sorted_arr) # [1 1 2 3 4 5 6 9]

# 原地排序
arr.sort()
print(arr) # [1 1 2 3 4 5 6 9]

# 获取排序后的索引
arr = np.array([3, 1, 4, 1, 5])
indices = np.argsort(arr)
print(indices) # [1 3 0 2 4]
print(arr[indices]) # [1 1 3 4 5]

# 二维数组排序
arr2d = np.array([[3, 1, 2], [6, 5, 4]])
print(np.sort(arr2d, axis=1)) # 按行排序
# [[1 2 3]
# [4 5 6]]

高级操作

广播机制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 标量与数组运算
arr = np.array([1, 2, 3, 4])
print(arr + 10) # [11 12 13 14]
print(arr * 2) # [2 4 6 8]

# 不同形状的数组运算
a = np.array([[1, 2, 3], [4, 5, 6]])
b = np.array([10, 20, 30])
print(a + b)
# [[11 22 33]
# [14 25 36]]

# 列向量广播
c = np.array([[10], [20]])
print(a + c)
# [[11 12 13]
# [24 25 26]]

# 广播规则:
# 1. 如果维度数不同,在较小的数组前面补 1
# 2. 如果某个维度大小为 1,则沿该维度复制
# 3. 如果维度大小不同且都不为 1,则报错

特征向量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
A = np.array([[4, 2], [1, 3]])

# 特征值和特征向量
eigenvalues, eigenvectors = np.linalg.eig(A)
print("特征值:", eigenvalues) # [5. 2.]
print("特征向量:\n", eigenvectors)
# [[ 0.89442719 -0.70710678]
# [ 0.4472136 0.70710678]]

# 验证:A * v = λ * v
for i in range(len(eigenvalues)):
v = eigenvectors[:, i]
λ = eigenvalues[i]
print(f"A·v{i} = {np.dot(A, v)}")
print(f"λ{i}·v{i} = {λ * v}")

线性方程组

1
2
3
4
5
6
7
8
9
10
# Ax = b
A = np.array([[2, 1], [1, 3]])
b = np.array([4, 5])

# 求解 x
x = np.linalg.solve(A, b)
print(x) # [1. 2.]

# 验证
print(np.dot(A, x)) # [4. 5.]

SVD 分解

1
2
3
4
5
6
7
8
9
10
11
12
13
A = np.array([[1, 2], [3, 4], [5, 6]])

# SVD 分解:A = U * S * Vt
U, S, Vt = np.linalg.svd(A)
print("U:\n", U)
print("S:", S)
print("Vt:\n", Vt)

# 重构矩阵
S_matrix = np.zeros(A.shape)
S_matrix[:len(S), :len(S)] = np.diag(S)
reconstructed = U @ S_matrix @ Vt
print("重构矩阵:\n", reconstructed)

基本随机数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 设置随机种子(保证可重复性)
np.random.seed(42)

# 0-1 之间的均匀分布随机数
print(np.random.rand(3, 3))

# 标准正态分布随机数
print(np.random.randn(3, 3))

# 指定范围的随机整数
print(np.random.randint(0, 10, (3, 3)))

# 指定范围的均匀分布随机数
print(np.random.uniform(0, 10, (3, 3)))

概率分布

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 二项分布
binomial = np.random.binomial(n=10, p=0.5, size=1000)
print(np.mean(binomial)) # 接近 5

# 泊松分布
poisson = np.random.poisson(lam=5, size=1000)
print(np.mean(poisson)) # 接近 5

# 正态分布
normal = np.random.normal(loc=0, scale=1, size=1000)
print(np.mean(normal)) # 接近 0
print(np.std(normal)) # 接近 1

# 贝塔分布
beta = np.random.beta(a=2, b=5, size=1000)

# 伽马分布
gamma = np.random.gamma(shape=2, scale=2, size=1000)

随机选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# 随机选择一个元素
print(np.random.choice(arr))

# 随机选择多个元素(可重复)
print(np.random.choice(arr, size=5))

# 随机选择多个元素(不重复)
print(np.random.choice(arr, size=5, replace=False))

# 带权重的随机选择
weights = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
print(np.random.choice(arr, size=5, p=weights))

# 随机打乱数组
np.random.shuffle(arr)
print(arr)

# 随机排列(返回新数组)
permuted = np.random.permutation(arr)
print(permuted)

性能优化

命名规范

1
2
3
4
5
6
7
8
9
10
11
12
13
# 导入约定
import numpy as np

# 变量命名
arr = np.array([1, 2, 3]) # 一般数组
matrix = np.array([[1, 2], [3, 4]]) # 矩阵
vector = np.array([1, 2, 3]) # 向量

# 常量命名
PI = np.pi
E = np.e
INF = np.inf
NAN = np.nan

避免循环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import time

# 低效:使用 Python 循环
def sum_with_loop(arr):
result = 0
for x in arr:
result += x ** 2
return result

# 高效:使用 NumPy 向量化
def sum_with_numpy(arr):
return np.sum(arr ** 2)

# 测试
arr = np.random.rand(1000000)

start = time.time()
result1 = sum_with_loop(arr)
print(f"循环耗时: {time.time() - start:.4f}秒")

start = time.time()
result2 = sum_with_numpy(arr)
print(f"向量化耗时: {time.time() - start:.4f}秒")

print(f"速度提升: {(time.time() - start) / (time.time() - start):.0f}倍")

广播优化

1
2
3
4
5
6
7
8
9
10
11
12
# 低效:使用循环
def normalize_rows_loop(arr):
result = np.zeros_like(arr)
for i in range(arr.shape[0]):
result[i] = arr[i] / np.sum(arr[i])
return result

# 高效:使用广播
def normalize_rows_broadcast(arr):
return arr / np.sum(arr, axis=1, keepdims=True)

arr = np.random.rand(1000, 100)

数据类型选择

1
2
3
4
5
6
7
8
# 根据数据范围选择合适的类型
arr_int8 = np.array([1, 2, 3], dtype=np.int8) # 节省内存
arr_int32 = np.array([1, 2, 3], dtype=np.int32)
arr_float32 = np.array([1.0, 2.0, 3.0], dtype=np.float32) # 精度足够时使用

print(f"int8 占用: {arr_int8.nbytes} 字节")
print(f"int32 占用: {arr_int32.nbytes} 字节")
print(f"float32 占用: {arr_float32.nbytes} 字节")

原地操作

1
2
3
4
5
6
7
8
9
10
arr = np.array([1, 2, 3, 4, 5])

# 创建新数组(占用更多内存)
result = arr * 2

# 原地操作(节省内存)
arr *= 2
print(arr) # [2 4 6 8 10]

# 注意:原地操作会修改原数组

视图而非副本

1
2
3
4
5
6
7
8
9
10
11
arr = np.array([1, 2, 3, 4, 5])

# 切片返回视图(不复制数据)
view = arr[1:4]
view[0] = 100
print(arr) # [ 1 100 3 4 5] # 原数组也被修改

# 显式复制
copy = arr[1:4].copy()
copy[0] = 200
print(arr) # [ 1 100 3 4 5] # 原数组不变

使用多线程

1
2
3
4
5
# NumPy 内部已经使用多线程优化
# 对于大型矩阵运算,会自动利用多核 CPU

# 查看 BLAS/LAPACK 配置
np.__config__.show()

使用 Numba 加速

1
2
3
4
5
6
7
8
9
10
11
from numba import jit

@jit(nopython=True)
def fast_computation(arr):
result = 0
for i in range(len(arr)):
result += arr[i] ** 2
return result

arr = np.random.rand(1000000)
result = fast_computation(arr)

其他优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# 1. 优先使用向量化运算
# 好
result = arr1 + arr2
# 不好
result = np.array([a + b for a, b in zip(arr1, arr2)])

# 2. 避免不必要的复制
# 好
view = arr[1:10]
# 不好
copy = arr[1:10].copy() # 除非确实需要副本

# 3. 使用合适的数据类型
# 好
arr = np.array([1, 2, 3], dtype=np.int32)
# 不好
arr = np.array([1, 2, 3], dtype=np.int64) # 如果不需要这么大的范围

# 4. 预分配数组
# 好
result = np.zeros(1000)
for i in range(1000):
result[i] = i ** 2
# 不好
result = []
for i in range(1000):
result.append(i ** 2)
result = np.array(result)

# 5. 使用内置函数
# 好
mean = np.mean(arr)
# 不好
mean = np.sum(arr) / len(arr)

使用示例

图像表示

1
2
3
4
5
6
7
8
9
10
11
12
# 灰度图像:二维数组
gray_image = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
print(gray_image.shape) # (100, 100)

# 彩色图像:三维数组(高度, 宽度, 通道)
color_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
print(color_image.shape) # (100, 100, 3)

# 提取颜色通道
red_channel = color_image[:, :, 0]
green_channel = color_image[:, :, 1]
blue_channel = color_image[:, :, 2]

图像操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 图像翻转
flipped_lr = np.fliplr(color_image) # 左右翻转
flipped_ud = np.flipud(color_image) # 上下翻转

# 图像裁剪
cropped = color_image[20:80, 20:80, :]

# 图像缩放(简单示例)
from scipy.ndimage import zoom
scaled = zoom(color_image, (0.5, 0.5, 1))

# 灰度化
gray = np.mean(color_image, axis=2).astype(np.uint8)

# 二值化
threshold = 128
binary = (gray > threshold).astype(np.uint8) * 255

数据清洗

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 创建包含缺失值的数据
data = np.array([1, 2, np.nan, 4, np.nan, 6, 7, 8, 9, 10])

# 检测缺失值
print(np.isnan(data)) # [False False True False True False False False False False]

# 删除缺失值
clean_data = data[~np.isnan(data)]
print(clean_data) # [ 1. 2. 4. 6. 7. 8. 9. 10.]

# 填补缺失值
mean_value = np.nanmean(data)
filled_data = np.where(np.isnan(data), mean_value, data)
print(filled_data)

# 异常值检测
data = np.random.randn(1000)
mean = np.mean(data)
std = np.std(data)
outliers = data[np.abs(data - mean) > 3 * std]
print(f"异常值数量: {len(outliers)}")

数据统计分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 创建示例数据
np.random.seed(42)
heights = np.random.normal(170, 10, 1000) # 身高(cm)
weights = np.random.normal(65, 15, 1000) # 体重(kg)

# 基本统计
print(f"平均身高: {np.mean(heights):.2f} cm")
print(f"身高标准差: {np.std(heights):.2f} cm")
print(f"身高范围: [{np.min(heights):.2f}, {np.max(heights):.2f}]")

# 相关系数
correlation = np.corrcoef(heights, weights)[0, 1]
print(f"身高体重相关系数: {correlation:.4f}")

# 百分位数
print(f"身高第 25 百分位: {np.percentile(heights, 25):.2f}")
print(f"身高中位数: {np.median(heights):.2f}")
print(f"身高第 75 百分位: {np.percentile(heights, 75):.2f}")

# 直方图
hist, bin_edges = np.histogram(heights, bins=20)
print(f"直方图频数: {hist}")
print(f" bins 边界: {bin_edges}")

特征标准化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Min-Max 标准化
def min_max_normalize(X):
min_val = np.min(X, axis=0)
max_val = np.max(X, axis=0)
return (X - min_val) / (max_val - min_val)

# Z-Score 标准化
def z_score_normalize(X):
mean = np.mean(X, axis=0)
std = np.std(X, axis=0)
return (X - mean) / std

# 示例
X = np.random.rand(100, 5) * 100
X_normalized = z_score_normalize(X)
print(f"标准化后均值: {np.mean(X_normalized, axis=0)}")
print(f"标准化后标准差: {np.std(X_normalized, axis=0)}")

数据增强

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 图像数据增强示例
def augment_image(image):
"""简单的图像增强"""
augmented = []

# 原始图像
augmented.append(image)

# 水平翻转
augmented.append(np.fliplr(image))

# 旋转 90 度
augmented.append(np.rot90(image))

# 添加噪声
noise = np.random.normal(0, 0.1, image.shape)
augmented.append(np.clip(image + noise, 0, 1))

return np.array(augmented)

# 示例
image = np.random.rand(64, 64, 3)
augmented_images = augment_image(image)
print(f"增强后数据形状: {augmented_images.shape}") # (4, 64, 64, 3)

常见问题

可变默认参数

1
2
3
4
5
6
7
8
9
10
11
# 错误示例
def add_to_array(arr=np.array([])):
arr = np.append(arr, 1)
return arr

# 正确示例
def add_to_array(arr=None):
if arr is None:
arr = np.array([])
arr = np.append(arr, 1)
return arr

浮点数精度

1
2
3
4
5
6
7
8
# 不要直接比较浮点数
a = 0.1 + 0.2
b = 0.3
print(a == b) # False

# 使用容差比较
print(np.isclose(a, b)) # True
print(np.allclose([a], [b])) # True

视图和副本

1
2
3
4
5
6
7
8
9
10
11
arr = np.array([1, 2, 3, 4, 5])

# 切片返回视图
view = arr[1:4]
view[0] = 100
print(arr) # [ 1 100 3 4 5] # 原数组被修改

# 如果需要独立副本,使用 .copy()
copy = arr[1:4].copy()
copy[0] = 200
print(arr) # [ 1 100 3 4 5] # 原数组不变

形状不匹配错误

1
2
3
4
5
6
7
8
# 错误:形状不匹配
a = np.array([1, 2, 3])
b = np.array([1, 2])
# print(a + b) # ValueError

# 解决方案:确保形状兼容或使用广播
b_padded = np.pad(b, (0, 1), 'constant')
print(a + b_padded) # [2 4 3]

数据类型转换

1
2
3
4
5
6
7
8
9
10
11
# 隐式类型转换
arr = np.array([1, 2, 3])
arr_float = arr / 2 # 自动转换为 float
print(arr_float.dtype) # float64

# 显式类型转换
arr_int = arr_float.astype(np.int32)
print(arr_int) # [0 1 1]
print(arr_int.dtype) # int32

# 注意:转换可能丢失精度

内存不足问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 问题:处理大型数组时内存不足
# large_arr = np.ones((10000, 10000, 10)) # 可能需要大量内存

# 解决方案:
# 1. 使用更小的数据类型
arr = np.ones((10000, 10000), dtype=np.float32) # 节省一半内存

# 2. 分块处理
chunk_size = 1000
for i in range(0, 10000, chunk_size):
chunk = arr[i:i+chunk_size, :]
# 处理 chunk

# 3. 使用内存映射
mmap_arr = np.memmap('large_array.dat', dtype='float32', mode='w+', shape=(10000, 10000))

Pandas

Pandas 是一个强大的 Python 数据分析库提供了高性能易用的数据结构和数据分析工具它基于 NumPy 构建能够处理结构化数据是数据科学和机器学习领域的核心工具

  • 核心功能
    • 数据结构Series 和 DataFrame 两种核心数据结构
    • 数据读写支持 CSVExcelSQLJSON 等多种格式
    • 数据清洗处理缺失值重复值异常值
    • 数据转换重塑透视合并分组聚合
    • 时间序列强大的日期和时间处理功能
  • 主要优势
    • 高效灵活基于 NumPy性能优异
    • 功能丰富提供全面的数据操作 API
    • 易于学习API 设计直观文档完善
    • 生态完善与 NumPyMatplotlibScikit-learn 等无缝集成

数据结构

  • Series
    • 一维标签数组类似带索引的列表
    • 可以存储任何数据类型
  • DataFrame
    • 二维表格型数据结构类似 Excel 表格或 SQL 表
    • 由多个 Series 组成每列可以是不同类型
  • Index
    • 行和列的标签支持多种类型
    • 提供快速查找和对齐功能

核心组件

  • 数据读取器
    • read_csvread_excelread_sql 等
    • 支持多种数据源和格式
  • 数据处理器
    • 清洗转换聚合等操作
    • 支持向量化运算性能优异
  • 分组引擎
    • groupby 操作类似 SQL 的 GROUP BY
    • 支持复杂的分组聚合计算

工作原理

数据处理流程

1
2
3
4
5
6
7
1. 数据加载 → 从文件/数据库/API 读取数据
2. 数据探索 → 查看数据结构统计信息
3. 数据清洗 → 处理缺失值异常值重复值
4. 数据转换 → 格式化重塑合并
5. 数据分析 → 统计分析分组聚合
6. 数据可视化 → 绘制图表展示结果
7. 数据导出 → 保存为文件或写入数据库

内存管理机制

1
2
3
4
5
6
7
数据加载 → 创建 DataFrame → 内存中操作 → 垃圾回收

关键特点
- 惰性加载大文件可分块读取
- 视图与副本某些操作返回视图而非副本
- 内存优化使用合适的数据类型减少内存占用
- 自动回收Python 垃圾回收机制管理内存

环境搭建

pip 安装

1
2
3
4
5
6
7
8
9
10
11
# 基础安装
pip install pandas

# 安装完整功能包含所有可选依赖
pip install pandas[all]

# 指定版本安装
pip install pandas==2.1.4

# 升级 Pandas
pip install --upgrade pandas

conda 安装

1
2
3
4
5
6
7
8
# 使用 conda 安装
conda install pandas

# 指定版本
conda install pandas=2.1.4

# 在指定环境中安装
conda install -n myenv pandas

验证安装

1
2
3
4
5
6
import pandas as pd
print(pd.__version__)
# 输出: 2.1.4

# 查看详细信息
pd.show_versions()

依赖环境

1
2
3
4
5
6
7
8
9
10
11
12
pandas 主要依赖
- numpy: 数值计算基础
- python-dateutil: 日期时间处理
- pytz: 时区处理
- tzdata: 时区数据

可选依赖
- openpyxl: Excel 文件读写
- xlrd: 旧版 Excel 文件读取
- sqlalchemy: 数据库连接
- matplotlib: 数据可视化
- scipy: 科学计算

💗💗 requirements.txt

1
2
3
4
5
pandas==2.1.4
numpy==1.26.2
openpyxl==3.1.2
matplotlib==3.8.2
sqlalchemy==2.0.23

基本配置

💗💗 显示选项配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import pandas as pd

# 设置显示最大行数
pd.set_option('display.max_rows', 100)

# 设置显示最大列数
pd.set_option('display.max_columns', 20)

# 设置列宽
pd.set_option('display.max_colwidth', 50)

# 设置浮点数显示精度
pd.set_option('display.float_format', lambda x: '%.2f' % x)

# 恢复默认设置
pd.reset_option('all')

# 查看当前所有选项
pd.describe_option()

💗💗 常用导入约定

1
2
3
4
5
6
7
8
9
10
11
# 标准导入方式
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Jupyter Notebook 中显示图表
%matplotlib inline

# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei'] # Windows
plt.rcParams['axes.unicode_minus'] = False

创建测试数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import pandas as pd
import numpy as np

# 从字典创建
data = {
'name': ['张三', '李四', '王五', '赵六'],
'age': [25, 30, 28, 35],
'city': ['北京', '上海', '广州', '深圳'],
'salary': [8000, 12000, 10000, 15000]
}
df = pd.DataFrame(data)
print(df)

# 从列表创建
data_list = [
['张三', 25, '北京', 8000],
['李四', 30, '上海', 12000],
['王五', 28, '广州', 10000],
['赵六', 35, '深圳', 15000]
]
df = pd.DataFrame(data_list, columns=['name', 'age', 'city', 'salary'])

# 生成随机数据
np.random.seed(42)
df_random = pd.DataFrame({
'A': np.random.randn(100),
'B': np.random.randint(0, 100, 100),
'C': np.random.choice(['男', '女'], 100),
'D': pd.date_range('2024-01-01', periods=100)
})


# 创建示例 CSV 文件
csv_content = """name,age,city,salary
张三,25,北京,8000
李四,30,上海,12000
王五,28,广州,10000
赵六,35,深圳,15000"""

with open('sample.csv', 'w', encoding='utf-8') as f:
f.write(csv_content)

# 读取 CSV 文件
df = pd.read_csv('sample.csv')
print(df)

基础使用

创建 Series

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import pandas as pd

# 从列表创建
s = pd.Series([1, 2, 3, 4, 5])
print(s)

# 指定索引
s = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'])
print(s)

# 从字典创建
data = {'a': 1, 'b': 2, 'c': 3}
s = pd.Series(data)
print(s)

# 从标量创建
s = pd.Series(5, index=['a', 'b', 'c', 'd'])
print(s)

操作 Series

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
s = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'])

# 访问元素
print(s['a']) # 通过标签访问
print(s[0]) # 通过位置访问
print(s[['a', 'c']]) # 访问多个元素

# 切片
print(s['a':'c']) # 标签切片包含末尾
print(s[0:3]) # 位置切片不包含末尾

# 布尔索引
print(s[s > 2]) # 筛选大于2的值

# 基本属性
print(s.index) # 索引
print(s.values) # 值数组
print(s.dtype) # 数据类型
print(s.shape) # 形状
print(s.size) # 大小

创建 DataFrame

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pandas as pd
import numpy as np

# 从字典创建
data = {
'name': ['张三', '李四', '王五'],
'age': [25, 30, 28],
'city': ['北京', '上海', '广州']
}
df = pd.DataFrame(data)
print(df)

# 指定索引
df = pd.DataFrame(data, index=['a', 'b', 'c'])
print(df)

# 从嵌套字典创建
data = {
'name': {'a': '张三', 'b': '李四'},
'age': {'a': 25, 'b': 30}
}
df = pd.DataFrame(data)
print(df)

# 从列表字典创建
data = [
{'name': '张三', 'age': 25},
{'name': '李四', 'age': 30}
]
df = pd.DataFrame(data)
print(df)

操作 DataFrame

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 查看数据
print(df.head()) # 前5行
print(df.tail()) # 后5行
print(df.head(10)) # 前10行

# 基本信息
print(df.shape) # 形状 (行数, 列数)
print(df.columns) # 列名
print(df.index) # 索引
print(df.dtypes) # 各列数据类型
print(df.info()) # 详细信息
print(df.describe()) # 统计摘要

# 访问列
print(df['name']) # 单列返回 Series
print(df[['name', 'age']]) # 多列返回 DataFrame

# 访问行
print(df.loc[0]) # 通过标签访问行
print(df.iloc[0]) # 通过位置访问行
print(df.loc[0:2]) # 行切片标签
print(df.iloc[0:3]) # 行切片位置

# 访问单元格
print(df.loc[0, 'name']) # 标签访问
print(df.iloc[0, 0]) # 位置访问
print(df.at[0, 'name']) # 快速标量访问
print(df.iat[0, 0]) # 快速标量访问

读取数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 读取 CSV
df = pd.read_csv('data.csv')
df = pd.read_csv('data.csv', encoding='utf-8')
df = pd.read_csv('data.csv', sep='\t') # Tab 分隔
df = pd.read_csv('data.csv', header=None) # 无表头
df = pd.read_csv('data.csv', names=['col1', 'col2']) # 指定列名
df = pd.read_csv('data.csv', index_col=0) # 指定索引列
df = pd.read_csv('data.csv', usecols=['name', 'age']) # 选择列
df = pd.read_csv('data.csv', nrows=100) # 读取前100行
df = pd.read_csv('data.csv', chunksize=1000) # 分块读取

# 读取 Excel
df = pd.read_excel('data.xlsx')
df = pd.read_excel('data.xlsx', sheet_name='Sheet1')
df = pd.read_excel('data.xlsx', sheet_name=0)

# 读取 JSON
df = pd.read_json('data.json')
df = pd.read_json('data.json', orient='records')

# 读取 SQL
from sqlalchemy import create_engine
engine = create_engine('sqlite:///database.db')
df = pd.read_sql('SELECT * FROM table', engine)
df = pd.read_sql_table('table', engine)
df = pd.read_sql_query('SELECT * FROM table', engine)

# 读取 HTML
dfs = pd.read_html('https://example.com/table.html')
df = dfs[0] # 返回的是列表

# 读取剪贴板
df = pd.read_clipboard()

写入数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 写入 CSV
df.to_csv('output.csv', index=False) # 不保存索引
df.to_csv('output.csv', encoding='utf-8-sig') # UTF-8 with BOM
df.to_csv('output.csv', sep='\t') # Tab 分隔

# 写入 Excel
df.to_excel('output.xlsx', index=False, sheet_name='Sheet1')

# 写入 JSON
df.to_json('output.json', orient='records', force_ascii=False)

# 写入 SQL
df.to_sql('table', engine, if_exists='replace', index=False)

# 写入剪贴板
df.to_clipboard(index=False)

数据清洗

查看数据质量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 查看缺失值
print(df.isnull().sum()) # 每列缺失值数量
print(df.isnull().sum().sum()) # 总缺失值数量
print(df.isnull().mean()) # 每列缺失比例
print(df.notnull().sum()) # 每列非缺失值数量

# 查看重复值
print(df.duplicated().sum()) # 重复行数量
print(df.duplicated(subset=['name'])) # 指定列的重复

# 查看异常值
print(df.describe()) # 统计摘要
print(df['age'].quantile([0.25, 0.5, 0.75])) # 分位数

# 数据类型检查
print(df.dtypes)
print(df.select_dtypes(include=['int64'])) # 选择整数列
print(df.select_dtypes(exclude=['object'])) # 排除字符串列

检测缺失值

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np

# 创建含缺失值的数据
df = pd.DataFrame({
'A': [1, 2, np.nan, 4, 5],
'B': [np.nan, 2, 3, 4, 5],
'C': [1, 2, 3, np.nan, 5]
})

# 检测缺失值
print(df.isnull()) # 返回布尔 DataFrame
print(df.notnull()) # 返回布尔 DataFrame
print(df.isna()) # 同 isnull()

删除缺失值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 删除含缺失值的行
df_clean = df.dropna()
df_clean = df.dropna(axis=0) # 明确指定行

# 删除含缺失值的列
df_clean = df.dropna(axis=1)

# 删除全部为缺失值的行/列
df_clean = df.dropna(how='all')

# 删除指定列含缺失值的行
df_clean = df.dropna(subset=['A', 'B'])

# 删除缺失值超过阈值的列
df_clean = df.dropna(thresh=3) # 至少3个非缺失值

填充缺失值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 用固定值填充
df_filled = df.fillna(0)
df_filled = df.fillna('未知')

# 用统计值填充
df_filled = df.fillna(df.mean()) # 均值
df_filled = df.fillna(df.median()) # 中位数
df_filled = df.fillna(df.mode().iloc[0]) # 众数

# 向前/向后填充
df_filled = df.fillna(method='ffill') # 向前填充
df_filled = df.fillna(method='bfill') # 向后填充

# 插值填充
df_filled = df.interpolate() # 线性插值
df_filled = df.interpolate(method='polynomial', order=2) # 多项式插值

# 不同列用不同值填充
df_filled = df.fillna({
'A': df['A'].mean(),
'B': df['B'].median(),
'C': 0
})

处理重复值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 检测重复值
print(df.duplicated()) # 返回布尔 Series
print(df.duplicated(subset=['name'])) # 指定列

# 删除重复值
df_unique = df.drop_duplicates()
df_unique = df.drop_duplicates(subset=['name']) # 指定列
df_unique = df.drop_duplicates(keep='first') # 保留第一个
df_unique = df.drop_duplicates(keep='last') # 保留最后一个
df_unique = df.drop_duplicates(keep=False) # 删除所有重复

# 查看重复的行
duplicates = df[df.duplicated(keep=False)]
print(duplicates)

识别异常值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 使用标准差识别
mean = df['age'].mean()
std = df['age'].std()
outliers = df[(df['age'] < mean - 3*std) | (df['age'] > mean + 3*std)]

# 使用 IQR四分位距识别
Q1 = df['age'].quantile(0.25)
Q3 = df['age'].quantile(0.75)
IQR = Q3 - Q1
outliers = df[(df['age'] < Q1 - 1.5*IQR) | (df['age'] > Q3 + 1.5*IQR)]

# 使用箱线图可视化
import matplotlib.pyplot as plt
plt.boxplot(df['age'])
plt.title('Age Boxplot')
plt.show()

处理异常值

1
2
3
4
5
6
7
8
9
10
11
12
13
# 删除异常值
df_clean = df[~((df['age'] < Q1 - 1.5*IQR) | (df['age'] > Q3 + 1.5*IQR))]

# 替换为边界值
df['age'] = df['age'].clip(lower=Q1 - 1.5*IQR, upper=Q3 + 1.5*IQR)

# 替换为中位数
median = df['age'].median()
df.loc[outliers.index, 'age'] = median

# 替换为均值
mean = df['age'].mean()
df.loc[outliers.index, 'age'] = mean

数据类型转换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 查看数据类型
print(df.dtypes)

# 转换数据类型
df['age'] = df['age'].astype(int)
df['salary'] = df['salary'].astype(float)
df['date'] = pd.to_datetime(df['date'])
df['category'] = df['category'].astype('category')

# 转换为数值类型处理非数值字符
df['value'] = pd.to_numeric(df['value'], errors='coerce') # 无法转换的变为 NaN
df['value'] = pd.to_numeric(df['value'], errors='ignore') # 无法转换的保持原样

# 分类类型优化内存
df['gender'] = df['gender'].astype('category')
print(df['gender'].cat.categories) # 查看类别
print(df.memory_usage(deep=True)) # 查看内存使用

数据过滤

列选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 选择单列
name = df['name']
name = df.name # 属性访问列名不能有空格或特殊字符

# 选择多列
subset = df[['name', 'age']]
subset = df.loc[:, ['name', 'age']]

# 选择列范围
subset = df.loc[:, 'name':'age']

# 按条件选择列
numeric_cols = df.select_dtypes(include=[np.number])
string_cols = df.select_dtypes(include=['object'])

行选择

💗💗 loc 标签索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 选择单行
row = df.loc[0]
row = df.loc['index_label']

# 选择多行
rows = df.loc[[0, 2, 4]]
rows = df.loc[0:5] # 包含末尾

# 行列同时选择
value = df.loc[0, 'name']
subset = df.loc[0:5, ['name', 'age']]

# 条件选择
subset = df.loc[df['age'] > 25]
subset = df.loc[(df['age'] > 25) & (df['city'] == '北京')]
subset = df.loc[df['name'].isin(['张三', '李四'])]

💗💗 iloc 位置索引

1
2
3
4
5
6
7
8
9
10
11
12
13
# 选择单行
row = df.iloc[0]

# 选择多行
rows = df.iloc[0:5]
rows = df.iloc[[0, 2, 4]]

# 行列同时选择
value = df.iloc[0, 0]
subset = df.iloc[0:5, 0:2]

# 步进选择
subset = df.iloc[::2] # 每隔一行

基本条件过滤

1
2
3
4
5
6
7
8
9
10
11
12
13
# 单条件
df_filtered = df[df['age'] > 25]
df_filtered = df[df['city'] == '北京']
df_filtered = df[df['name'] != '张三']

# 多条件
df_filtered = df[(df['age'] > 25) & (df['city'] == '北京')] # 且
df_filtered = df[(df['age'] > 25) | (df['city'] == '上海')] # 或
df_filtered = df[~(df['age'] > 25)] # 非

# 范围条件
df_filtered = df[df['age'].between(25, 30)]
df_filtered = df[df['salary'].between(8000, 12000)]

字符串条件过滤

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 包含
df_filtered = df[df['name'].str.contains('张')]
df_filtered = df[df['city'].str.contains('北|上')] # 正则

# 开头/结尾
df_filtered = df[df['name'].str.startswith('张')]
df_filtered = df[df['city'].str.endswith('京')]

# 长度
df_filtered = df[df['name'].str.len() > 2]

# 大小写
df_filtered = df[df['name'].str.upper() == '张三'.upper()]

# 替换
df['name_clean'] = df['name'].str.replace(' ', '')

高级过滤过滤

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# isin 过滤
cities = ['北京', '上海']
df_filtered = df[df['city'].isin(cities)]

# notin 过滤
df_filtered = df[~df['city'].isin(cities)]

# query 方法更简洁
df_filtered = df.query('age > 25 and city == "北京"')
df_filtered = df.query('age > @min_age', min_age=25) # 使用变量

# nlargest/nsmallest
top5 = df.nlargest(5, 'salary')
bottom5 = df.nsmallest(5, 'age')

数据排序

基本排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 单列排序
df_sorted = df.sort_values(by='age')
df_sorted = df.sort_values(by='age', ascending=False) # 降序

# 多列排序
df_sorted = df.sort_values(by=['city', 'age'])
df_sorted = df.sort_values(by=['city', 'age'], ascending=[True, False])

# 按索引排序
df_sorted = df.sort_index()
df_sorted = df.sort_index(ascending=False)

# 重置索引
df_sorted = df.sort_values(by='age').reset_index(drop=True)

排名

1
2
3
4
5
6
7
8
9
10
11
12
# 基本排名
df['rank'] = df['salary'].rank()
df['rank'] = df['salary'].rank(ascending=False) # 降序排名

# 排名方法
df['rank_avg'] = df['salary'].rank(method='average') # 平均排名默认
df['rank_min'] = df['salary'].rank(method='min') # 最小排名
df['rank_max'] = df['salary'].rank(method='max') # 最大排名
df['rank_first'] = df['salary'].rank(method='first') # 首次出现顺序

# 百分比排名
df['percentile'] = df['salary'].rank(pct=True)

数据转换

添加列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 直接赋值
df['bonus'] = df['salary'] * 0.1
df['full_info'] = df['name'] + ' (' + df['city'] + ')'

# 使用 assign返回新 DataFrame
df_new = df.assign(
bonus=lambda x: x['salary'] * 0.1,
tax=lambda x: x['salary'] * 0.2
)

# 使用 apply
df['age_group'] = df['age'].apply(lambda x: '青年' if x < 30 else '中年')

# 使用 np.where
import numpy as np
df['level'] = np.where(df['salary'] > 10000, '高', '低')

# 使用 cut 分箱
df['age_bin'] = pd.cut(df['age'], bins=[0, 25, 35, 50], labels=['青年', '中年', '老年'])

删除列

1
2
3
4
5
6
7
8
9
10
11
12
13
# 删除单列
df_dropped = df.drop('bonus', axis=1)
df_dropped = df.drop(columns='bonus')

# 删除多列
df_dropped = df.drop(['bonus', 'tax'], axis=1)
df_dropped = df.drop(columns=['bonus', 'tax'])

# 原地删除
df.drop(columns=['bonus'], inplace=True)

# 按条件删除列
df_dropped = df.dropna(axis=1) # 删除含缺失值的列

重命名

1
2
3
4
5
6
7
8
9
10
11
12
13
# 重命名列
df_renamed = df.rename(columns={'name': '姓名', 'age': '年龄'})

# 重命名索引
df_renamed = df.rename(index={0: 'a', 1: 'b'})

# 使用函数重命名
df_renamed = df.rename(columns=str.upper)
df_renamed = df.rename(columns=lambda x: x.replace(' ', '_'))

# 设置索引名称
df.index.name = 'id'
df.columns.name = 'attributes'

apply 方法

1
2
3
4
5
6
7
8
9
10
11
12
# 对列应用函数
df['age_squared'] = df['age'].apply(lambda x: x ** 2)

# 对行应用函数
def custom_func(row):
return row['salary'] * 0.1 if row['age'] > 30 else row['salary'] * 0.05

df['bonus'] = df.apply(custom_func, axis=1)

# 使用内置函数
df['name_length'] = df['name'].apply(len)
df['name_upper'] = df['name'].apply(str.upper)

map 方法

1
2
3
4
5
6
7
8
9
# Series 映射
city_code = {'北京': 'BJ', '上海': 'SH', '广州': 'GZ'}
df['city_code'] = df['city'].map(city_code)

# 使用函数
df['age_group'] = df['age'].map(lambda x: '青年' if x < 30 else '中年')

# replace 方法
df['city'] = df['city'].replace({'北京': 'Beijing', '上海': 'Shanghai'})

transform 方法

1
2
3
4
5
6
7
8
9
# 分组后变换
df['salary_norm'] = df.groupby('city')['salary'].transform(
lambda x: (x - x.mean()) / x.std()
)

# 标准化
df['age_zscore'] = df.groupby('city')['age'].transform(
lambda x: (x - x.mean()) / x.std()
)

数据聚合

基本统计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 描述性统计
print(df.describe())
print(df.describe(include='all')) # 包含所有类型

# 单列统计
print(df['age'].mean()) # 均值
print(df['age'].median()) # 中位数
print(df['age'].std()) # 标准差
print(df['age'].var()) # 方差
print(df['age'].min()) # 最小值
print(df['age'].max()) # 最大值
print(df['age'].sum()) # 求和
print(df['age'].count()) # 计数
print(df['age'].quantile(0.75)) # 分位数

# 相关系数
print(df.corr())
print(df['age'].corr(df['salary']))

# 协方差
print(df.cov())

基本分组

1
2
3
4
5
6
7
8
9
10
11
12
13
# 单列分组
grouped = df.groupby('city')
print(grouped['salary'].mean())
print(grouped['salary'].agg(['mean', 'sum', 'count']))

# 多列分组
grouped = df.groupby(['city', 'gender'])
print(grouped['salary'].mean())

# 遍历分组
for name, group in df.groupby('city'):
print(f"城市: {name}")
print(group)

聚合函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 单个聚合
result = df.groupby('city')['salary'].mean()
result = df.groupby('city')['salary'].sum()
result = df.groupby('city')['salary'].count()

# 多个聚合
result = df.groupby('city')['salary'].agg(['mean', 'sum', 'count'])
result = df.groupby('city')['salary'].agg([
('平均值', 'mean'),
('总和', 'sum'),
('数量', 'count')
])

# 不同列不同聚合
result = df.groupby('city').agg({
'salary': ['mean', 'sum'],
'age': ['min', 'max'],
'name': 'count'
})

# 自定义聚合函数
def range_func(x):
return x.max() - x.min()

result = df.groupby('city')['salary'].agg(range_func)

分组变换

1
2
3
4
5
6
7
8
9
10
# 组内标准化
df['salary_zscore'] = df.groupby('city')['salary'].transform(
lambda x: (x - x.mean()) / x.std()
)

# 组内排名
df['salary_rank'] = df.groupby('city')['salary'].rank(ascending=False)

# 组内累计求和
df['cumulative_salary'] = df.groupby('city')['salary'].cumsum()

透视表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 基本透视表
pivot = pd.pivot_table(df, values='salary', index='city', aggfunc='mean')

# 多维度透视表
pivot = pd.pivot_table(
df,
values='salary',
index='city',
columns='gender',
aggfunc='mean',
fill_value=0,
margins=True, # 添加总计
margins_name='总计'
)

# 多个聚合函数
pivot = pd.pivot_table(
df,
values='salary',
index='city',
aggfunc=['mean', 'sum', 'count']
)

交叉表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 基本交叉表
cross = pd.crosstab(df['city'], df['gender'])

# 带值的交叉表
cross = pd.crosstab(
df['city'],
df['gender'],
values=df['salary'],
aggfunc='mean'
)

# 标准化交叉表
cross = pd.crosstab(df['city'], df['gender'], normalize='index') # 按行标准化
cross = pd.crosstab(df['city'], df['gender'], normalize='columns') # 按列标准化
cross = pd.crosstab(df['city'], df['gender'], normalize='all') # 全局标准化

数据合并

join 连接

1
2
3
4
5
6
7
8
9
10
11
# 基于索引连接
df1 = pd.DataFrame({'A': [1, 2, 3]}, index=['a', 'b', 'c'])
df2 = pd.DataFrame({'B': [4, 5, 6]}, index=['b', 'c', 'd'])

df_joined = df1.join(df2, how='inner')
df_joined = df1.join(df2, how='left')
df_joined = df1.join(df2, how='right')
df_joined = df1.join(df2, how='outer')

# 多个 DataFrame 连接
df_joined = df1.join([df2, df3], how='outer')

concat 连接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 垂直连接追加行
df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
df2 = pd.DataFrame({'A': [5, 6], 'B': [7, 8]})
df_concat = pd.concat([df1, df2])
df_concat = pd.concat([df1, df2], ignore_index=True) # 重置索引

# 水平连接追加列
df3 = pd.DataFrame({'C': [9, 10], 'D': [11, 12]})
df_concat = pd.concat([df1, df3], axis=1)

# 多个 DataFrame 连接
df_concat = pd.concat([df1, df2, df3], ignore_index=True)

# 指定键
df_concat = pd.concat([df1, df2], keys=['df1', 'df2'])

# 连接方式
df_concat = pd.concat([df1, df2], join='inner') # 内连接只保留共同列
df_concat = pd.concat([df1, df2], join='outer') # 外连接保留所有列默认

append 追加

1
2
3
4
5
# 旧方法已弃用
df_appended = df1.append(df2)

# 新方法推荐
df_appended = pd.concat([df1, df2], ignore_index=True)

merge 合并

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# 创建示例数据
df_left = pd.DataFrame({
'key': ['A', 'B', 'C', 'D'],
'value_left': [1, 2, 3, 4]
})

df_right = pd.DataFrame({
'key': ['B', 'D', 'E', 'F'],
'value_right': [5, 6, 7, 8]
})

# 内连接只保留匹配的
df_merged = pd.merge(df_left, df_right, on='key', how='inner')

# 左连接保留左表所有
df_merged = pd.merge(df_left, df_right, on='key', how='left')

# 右连接保留右表所有
df_merged = pd.merge(df_left, df_right, on='key', how='right')

# 外连接保留所有
df_merged = pd.merge(df_left, df_right, on='key', how='outer')

# 多键合并
df_merged = pd.merge(df_left, df_right, on=['key1', 'key2'])

# 不同列名合并
df_merged = pd.merge(
df_left, df_right,
left_on='key_left',
right_on='key_right'
)

# 指示符合并
df_merged = pd.merge(df_left, df_right, on='key', how='outer', indicator=True)
print(df_merged['_merge']) # both, left_only, right_only

时间序列

创建时间序列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 创建日期范围
dates = pd.date_range('2024-01-01', periods=10)
dates = pd.date_range('2024-01-01', '2024-01-10')
dates = pd.date_range('2024-01-01', periods=10, freq='D') # 天
dates = pd.date_range('2024-01-01', periods=10, freq='M') # 月
dates = pd.date_range('2024-01-01', periods=10, freq='W') # 周

# 常用频率
# D: 天, W: 周, M: 月末, MS: 月初
# H: 小时, T/min: 分钟, S: 秒
# B: 工作日, BM: 工作月末

# 创建时间序列 DataFrame
df_time = pd.DataFrame({
'date': pd.date_range('2024-01-01', periods=100),
'value': np.random.randn(100)
})
df_time = df_time.set_index('date')

时间转换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 字符串转时间
df['date'] = pd.to_datetime(df['date_string'])
df['date'] = pd.to_datetime(df['date_string'], format='%Y-%m-%d')
df['date'] = pd.to_datetime(df['date_string'], errors='coerce') # 错误变NaT

# 时间转字符串
df['date_str'] = df['date'].dt.strftime('%Y-%m-%d')
df['date_str'] = df['date'].dt.strftime('%Y年%m月%d日')

# 提取时间组件
df['year'] = df['date'].dt.year
df['month'] = df['date'].dt.month
df['day'] = df['date'].dt.day
df['hour'] = df['date'].dt.hour
df['minute'] = df['date'].dt.minute
df['weekday'] = df['date'].dt.weekday # 0=周一
df['dayofweek'] = df['date'].dt.dayofweek
df['day_name'] = df['date'].dt.day_name()
df['month_name'] = df['date'].dt.month_name()
df['quarter'] = df['date'].dt.quarter
df['is_month_start'] = df['date'].dt.is_month_start
df['is_month_end'] = df['date'].dt.is_month_end

时间索引操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 设置时间索引
df = df.set_index('date')

# 时间选择
df['2024-01'] # 选择2024年1月
df['2024-01-01'] # 选择特定日期
df['2024-01-01':'2024-01-31'] # 时间范围
df.loc['2024-01'] # 使用loc

# 重采样
df_daily = df.resample('D').mean() # 按天
df_weekly = df.resample('W').mean() # 按周
df_monthly = df.resample('M').mean() # 按月

# 重采样聚合
df_resampled = df.resample('M').agg({
'value': ['mean', 'sum', 'count']
})

# 移动窗口
df['rolling_mean'] = df['value'].rolling(window=7).mean() # 7天移动平均
df['rolling_std'] = df['value'].rolling(window=7).std()
df['expanding_mean'] = df['value'].expanding().mean() # 累积平均

时区处理

1
2
3
4
5
6
7
8
# 本地化时区
df['date'] = df['date'].dt.tz_localize('Asia/Shanghai')

# 转换时区
df['date_utc'] = df['date'].dt.tz_convert('UTC')

# 移除时区
df['date_naive'] = df['date'].dt.tz_localize(None)

进阶功能

多级索引

💗💗 创建多级索引

1
2
3
4
5
6
7
8
9
10
# 从现有 DataFrame 创建
df_multi = df.set_index(['city', 'gender'])

# 从外部创建
arrays = [
['A', 'A', 'B', 'B'],
['one', 'two', 'one', 'two']
]
index = pd.MultiIndex.from_arrays(arrays, names=['first', 'second'])
df_multi = pd.DataFrame(np.random.randn(4, 2), index=index, columns=['col1', 'col2'])

💗💗 多级索引操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 选择数据
df_multi.loc['A']
df_multi.loc['A', 'one']
df_multi.loc[('A', 'one')]
df_multi.loc[pd.IndexSlice['A', :], :]

# 交换层级
df_swapped = df_multi.swaplevel(0, 1)

# 排序索引
df_sorted = df_multi.sort_index(level=0)

# 重置索引
df_reset = df_multi.reset_index()
df_reset = df_multi.reset_index(level=1) # 只重置某层

窗口函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 滚动窗口
df['rolling_mean'] = df['value'].rolling(window=5).mean()
df['rolling_sum'] = df['value'].rolling(window=5).sum()
df['rolling_std'] = df['value'].rolling(window=5).std()

# 最小周期
df['rolling_mean'] = df['value'].rolling(window=5, min_periods=3).mean()

# 指数加权移动平均
df['ewm_mean'] = df['value'].ewm(span=5).mean()

# 扩展窗口累积
df['cumulative_mean'] = df['value'].expanding().mean()
df['cumulative_sum'] = df['value'].expanding().sum()

分类数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 创建分类数据
df['grade'] = pd.Categorical(['A', 'B', 'C', 'A', 'B'], categories=['A', 'B', 'C', 'D'])

# 有序分类
df['size'] = pd.Categorical(['S', 'M', 'L', 'XL'], categories=['S', 'M', 'L', 'XL'], ordered=True)

# 分类操作
print(df['grade'].cat.categories) # 查看类别
print(df['grade'].cat.codes) # 查看编码

# 添加类别
df['grade'] = df['grade'].cat.add_categories(['E'])

# 移除类别
df['grade'] = df['grade'].cat.remove_categories(['D'])

# 重命名类别
df['grade'] = df['grade'].cat.rename_categories({'A': '优秀', 'B': '良好'})

# 重新排序
df['size'] = df['size'].cat.reorder_categories(['XL', 'L', 'M', 'S'])

稀疏数据

1
2
3
4
5
6
7
8
9
10
11
12
# 创建稀疏 Series
s_sparse = pd.Series([0, 0, 1, 0, 0, 2], dtype=pd.SparseDtype("int", fill_value=0))

# 转换为稀疏
df_sparse = df.astype(pd.SparseDtype("float", fill_value=0))

# 查看稀疏信息
print(df_sparse.sparse.density) # 密度
print(df_sparse.sparse.fill_value) # 填充值

# 转回密集
df_dense = df_sparse.sparse.to_dense()

性能优化

代码规范

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 1. 使用有意义的变量名
# 不好
df1 = pd.read_csv('data.csv')
df2 = df1[df1['age'] > 25]

# 好
customers = pd.read_csv('customers.csv')
adult_customers = customers[customers['age'] > 25]

# 2. 链式操作要清晰
# 不好
result = df[df['age'] > 25].groupby('city')['salary'].mean().reset_index()

# 好
result = (
df[df['age'] > 25]
.groupby('city')['salary']
.mean()
.reset_index()
)

# 3. 使用 copy 避免 SettingWithCopyWarning
df_subset = df[df['age'] > 25].copy()
df_subset['new_col'] = df_subset['salary'] * 0.1

# 4. 及时释放内存
del large_df
import gc
gc.collect()

性能建议

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 1. 读取时指定数据类型
df = pd.read_csv('data.csv', dtype={
'id': 'int32',
'category': 'category',
'amount': 'float32'
})

# 2. 只读取需要的列
df = pd.read_csv('data.csv', usecols=['id', 'name', 'age'])

# 3. 使用合适的文件格式
# Parquet 比 CSV 更快更小
df.to_parquet('data.parquet')
df = pd.read_parquet('data.parquet')

# 4. 避免迭代行
# 不好
for index, row in df.iterrows():
print(row['name'])

# 好
for name in df['name']:
print(name)

# 或使用 itertuples更快
for row in df.itertuples():
print(row.name)

常见陷阱

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 1. 链式索引问题
# 可能产生警告
df[df['age'] > 25]['salary'] = 10000

# 正确做法
df.loc[df['age'] > 25, 'salary'] = 10000

# 2. 视图 vs 副本
df_subset = df[df['age'] > 25] # 可能是视图
df_subset = df[df['age'] > 25].copy() # 明确创建副本

# 3. 缺失值比较
# 错误
if df['value'] == np.nan:
pass

# 正确
if pd.isna(df['value']):
pass

# 4. 日期解析
# 让 Pandas 自动解析
df = pd.read_csv('data.csv', parse_dates=['date_column'])

# 或手动指定格式
df['date'] = pd.to_datetime(df['date_str'], format='%Y-%m-%d')

数据质量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 问题缺失值过多
# 解决方案
# 1. 分析缺失模式
print(df.isnull().sum())
print(df.isnull().mean())

# 2. 选择合适的填充策略
df_filled = df.fillna(df.mean()) # 均值填充
df_filled = df.fillna(method='ffill') # 前向填充

# 3. 删除缺失过多的列
df_clean = df.dropna(thresh=len(df)*0.8, axis=1)

# 问题数据类型不一致
# 解决方案
# 1. 统一数据类型
df['column'] = pd.to_numeric(df['column'], errors='coerce')

# 2. 处理异常值
Q1 = df['column'].quantile(0.25)
Q3 = df['column'].quantile(0.75)
df_clean = df[(df['column'] >= Q1 - 1.5*IQR) & (df['column'] <= Q3 + 1.5*IQR)]

内存优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# 查看内存使用
print(df.memory_usage(deep=True))
print(df.memory_usage(deep=True).sum() / 1024**2, 'MB')

# 优化整数类型
df['small_int'] = df['small_int'].astype('int8') # -128 to 127
df['medium_int'] = df['medium_int'].astype('int16') # -32768 to 32767
df['large_int'] = df['large_int'].astype('int32') # -2^31 to 2^31-1

# 优化浮点类型
df['float32_col'] = df['float64_col'].astype('float32')

# 使用分类类型
df['category_col'] = df['category_col'].astype('category')

# 优化字符串
df['string_col'] = df['string_col'].astype('string') # Pandas 1.0+

# 批量优化
def optimize_dataframe(df):
start_mem = df.memory_usage(deep=True).sum() / 1024**2

for col in df.columns:
col_type = df[col].dtype

if col_type != 'object':
c_min = df[col].min()
c_max = df[col].max()

if str(col_type)[:3] == 'int':
if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
df[col] = df[col].astype(np.int8)
elif c_min >= np.iinfo(np.int16).min and c_max <= np.iinfo(np.int16).max:
df[col] = df[col].astype(np.int16)
elif c_min >= np.iinfo(np.int32).min and c_max <= np.iinfo(np.int32).max:
df[col] = df[col].astype(np.int32)
else:
df[col] = df[col].astype(np.int64)
else:
if c_min >= np.finfo(np.float16).min and c_max <= np.finfo(np.float16).max:
df[col] = df[col].astype(np.float16)
elif c_min >= np.finfo(np.float32).min and c_max <= np.finfo(np.float32).max:
df[col] = df[col].astype(np.float32)
else:
df[col] = df[col].astype(np.float64)
else:
df[col] = df[col].astype('category')

end_mem = df.memory_usage(deep=True).sum() / 1024**2
print(f'Memory usage decreased from {start_mem:.2f} MB to {end_mem:.2f} MB')
print(f'Reduction: {100 * (start_mem - end_mem) / start_mem:.1f}%')

return df

df_optimized = optimize_dataframe(df)

向量化操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 避免使用循环
# 慢使用循环
result = []
for value in df['column']:
result.append(value * 2)
df['new_column'] = result

# 快向量化操作
df['new_column'] = df['column'] * 2

# 避免使用 apply除非必要
# 较慢
df['new'] = df['col'].apply(lambda x: x * 2)

# 较快
df['new'] = df['col'] * 2

# 使用 NumPy 函数
df['log_value'] = np.log(df['value'])
df['sqrt_value'] = np.sqrt(df['value'])

分块处理大数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 分块读取
chunk_size = 10000
chunks = pd.read_csv('large_file.csv', chunksize=chunk_size)

results = []
for chunk in chunks:
# 处理每个块
processed = chunk[chunk['value'] > 100]
results.append(processed)

# 合并结果
df_result = pd.concat(results, ignore_index=True)

# 分块写入
writer = pd.ExcelWriter('output.xlsx')
for i, chunk in enumerate(chunks):
chunk.to_excel(writer, sheet_name=f'Sheet{i}', index=False)
writer.close()

SQLAlchemy

SQLAlchemy 是 Python 中最流行的 ORM对象关系映射框架它提供了企业级持久化模式用于高效和高性能的数据库访问

  • 核心功能
    • ORM 映射将 Python 类与数据库表进行映射
    • SQL 表达式语言提供 Pythonic 的 SQL 构建方式
    • 会话管理事务管理和连接池
    • 迁移支持配合 Alembic 实现数据库版本控制
  • 主要优势
    • 灵活性高支持 ORM 和 Core 两种使用方式
    • 性能优秀接近原生 SQL 的性能
    • 数据库无关支持多种数据库后端
    • 社区活跃丰富的插件和扩展生态

核心组件

  • Engine
    • 数据库连接的核心接口
    • 管理连接池和方言
  • Session
    • 工作单元模式跟踪所有对象变化
    • 提供事务管理
  • Model/Declarative Base
    • 定义数据库表的 Python 类
    • SQLAlchemy 自动生成 SQL
  • Query
    • 构建和执行查询的对象
    • 支持链式调用

工作原理

执行流程

1
2
3
4
5
6
1. 创建 Engine → 配置数据库连接
2. 定义 Model 类 → 映射数据库表
3. 创建 Session → 管理工作单元
4. 通过 Session 执行 CRUD 操作
5. 提交事务或回滚
6. 关闭 Session

SQL 生成过程

1
2
3
4
5
6
7
Python 对象操作 → ORM 转换 → SQL 表达式 → 编译为 SQL → 执行

关键步骤
- 映射将 Python 类映射到数据库表
- 跟踪Session 跟踪对象状态变化
- 刷新将变化同步到数据库
- 提交持久化所有更改

环境搭建

1
2
3
4
5
6
7
8
9
10
11
# 基础安装
pip install sqlalchemy

# 安装特定数据库驱动
pip install pymysql # MySQL
pip install psycopg2 # PostgreSQL
pip install cx_Oracle # Oracle
pip install sqlite3 # SQLitePython 内置

# 完整安装推荐
pip install sqlalchemy[pymysql]

💗💗 requirements.txt

1
2
3
sqlalchemy==2.0.23
pymysql==1.1.0
alembic==1.13.0

基本使用

创建数据库表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
-- 创建用户表
CREATE TABLE `user` (
`id` INT PRIMARY KEY AUTO_INCREMENT,
`username` VARCHAR(50) NOT NULL,
`password` VARCHAR(100) NOT NULL,
`email` VARCHAR(100),
`age` INT,
`create_time` DATETIME DEFAULT CURRENT_TIMESTAMP,
`update_time` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- 插入测试数据
INSERT INTO `user` (`username`, `password`, `email`, `age`) VALUES
('张三', '123456', 'zhangsan@example.com', 25),
('李四', '123456', 'lisi@example.com', 30),
('王五', '123456', 'wangwu@example.com', 28);

创建模型类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from sqlalchemy import create_engine, Column, Integer, String, DateTime
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime

# 创建基类
Base = declarative_base()

# 定义用户模型
class User(Base):
__tablename__ = 'user'

id = Column(Integer, primary_key=True, autoincrement=True)
username = Column(String(50), nullable=False)
password = Column(String(100), nullable=False)
email = Column(String(100))
age = Column(Integer)
create_time = Column(DateTime, default=datetime.now)
update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now)

def __repr__(self):
return f"<User(id={self.id}, username='{self.username}')>"

def to_dict(self):
return {
'id': self.id,
'username': self.username,
'email': self.email,
'age': self.age
}

数据库连接

💗💗 第一种连接方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session
from urllib.parse import quote_plus

# 创建引擎
# 注意密码中的 @ 符号需要 URL 编码或者使用 quote_plus
DB_USER = "root"
DB_PASSWORD = quote_plus("123456@") # 假设密码包含 @
DB_HOST = "localhost"
DB_PORT = 3306
DB_NAME = "sqlalchemy_demo"

# 建议从环境变量获取密码避免硬编码
# DB_PASSWORD = quote_plus(os.getenv("DB_PASSWORD", "default_pass"))

DATABASE_URL = f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

engine = create_engine(
DATABASE_URL,
echo=False, # 生产环境建议关闭 echo或改用 logging
pool_size=5,
pool_recycle=3600,
pool_pre_ping=True # 建议开启自动检测断连并重连
)


def get_raw_data():
# 使用 with 确保连接释放
with Session(engine) as session:
try:
# 使用 text() 包装原生 SQL
result = session.execute(text("SELECT * FROM project"))

# 获取所有行
rows = result.fetchall()

for row in rows:
print(row)

return rows
except Exception as e:
session.rollback()
raise e

get_raw_data()

💗💗 第二种配置文件方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# config.py
import os

DATABASE_CONFIG = {
'driver': os.getenv('DB_DRIVER', 'mysql+pymysql'),
'username': os.getenv('DB_USERNAME', 'root'),
'password': os.getenv('DB_PASSWORD', '123456'),
'host': os.getenv('DB_HOST', 'localhost'),
'port': os.getenv('DB_PORT', '3306'),
'database': os.getenv('DB_NAME', 'sqlalchemy_demo')
}

def get_database_url():
return "{driver}://{username}:{password}@{host}:{port}/{database}".format(
**DATABASE_CONFIG
)

创建数据库表

1
2
3
4
5
6
7
8
9
10
11
from sqlalchemy import create_engine
from models import Base, User

# 创建引擎
engine = create_engine('mysql+pymysql://root:123456@localhost:3306/sqlalchemy_demo')

# 创建所有表
Base.metadata.create_all(engine)

# 删除所有表谨慎使用
# Base.metadata.drop_all(engine)

CRUD 操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from sqlalchemy.orm import Session
from models import User, SessionLocal

# 创建会话
session = SessionLocal()

try:
# 插入数据
new_user = User(username='赵六', password='123456', email='zhaoliu@example.com', age=26)
session.add(new_user)
session.commit()
print(f"新用户 ID: {new_user.id}")

# 查询数据
user = session.query(User).filter_by(username='张三').first()
print(f"查询结果: {user}")

# 更新数据
user.age = 26
session.commit()

# 删除数据
session.delete(user)
session.commit()

finally:
session.close()

基本查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from sqlalchemy import and_, or_

# 查询所有记录
users = session.query(User).all()

# 根据主键查询
user = session.query(User).get(1)

# 条件查询
user = session.query(User).filter_by(username='张三').first()

# 多条件查询
users = session.query(User).filter(
and_(User.age > 20, User.age < 30)
).all()

# OR 条件
users = session.query(User).filter(
or_(User.username == '张三', User.username == '李四')
).all()

模糊查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# LIKE 查询
users = session.query(User).filter(
User.username.like('%张%')
).all()

# 不区分大小写
users = session.query(User).filter(
User.username.ilike('%zhang%')
).all()

# 正则表达式MySQL
from sqlalchemy import func
users = session.query(User).filter(
User.username.regexp_match('^张.*')
).all()

排序分页

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 排序
users = session.query(User).order_by(User.age.desc()).all()
users = session.query(User).order_by(User.username.asc(), User.age.desc()).all()

# 限制返回数量
users = session.query(User).limit(10).all()

# 偏移量
users = session.query(User).offset(20).limit(10).all()

# 分页查询
page = 2
per_page = 10
users = session.query(User).offset((page - 1) * per_page).limit(per_page).all()

聚合查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from sqlalchemy import func

# 计数
count = session.query(func.count(User.id)).scalar()

# 平均值
avg_age = session.query(func.avg(User.age)).scalar()

# 最大值/最小值
max_age = session.query(func.max(User.age)).scalar()
min_age = session.query(func.min(User.age)).scalar()

# 分组统计
from sqlalchemy import func
result = session.query(
User.age,
func.count(User.id).label('count')
).group_by(User.age).all()

for age, count in result:
print(f"年龄 {age}: {count} 人")

单条插入

1
2
3
4
5
6
7
8
9
10
11
12
13
# 方式一add
user = User(username='测试用户', password='123456', email='test@example.com', age=25)
session.add(user)
session.commit()

# 方式二add_all批量
users = [
User(username='用户1', password='123456', age=20),
User(username='用户2', password='123456', age=22),
User(username='用户3', password='123456', age=24)
]
session.add_all(users)
session.commit()

批量插入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 方式一bulk_insert_mappings高性能
user_dicts = [
{'username': '用户1', 'password': '123456', 'age': 20},
{'username': '用户2', 'password': '123456', 'age': 22},
{'username': '用户3', 'password': '123456', 'age': 24}
]
session.bulk_insert_mappings(User, user_dicts)
session.commit()

# 方式二execute + insert
from sqlalchemy import insert
stmt = insert(User).values([
{'username': '用户1', 'password': '123456', 'age': 20},
{'username': '用户2', 'password': '123456', 'age': 22}
])
session.execute(stmt)
session.commit()

单条更新

1
2
3
4
5
6
7
8
9
10
11
12
# 查询后更新
user = session.query(User).get(1)
user.username = '新用户名'
user.email = 'newemail@example.com'
session.commit()

# 直接更新不加载对象
session.query(User).filter(User.id == 1).update({
'username': '新用户名',
'email': 'newemail@example.com'
})
session.commit()

批量更新

1
2
3
4
5
6
7
8
9
10
# 批量更新所有符合条件的记录
session.query(User).filter(User.age < 18).update({
'age': User.age + 1
}, synchronize_session='fetch')
session.commit()

# synchronize_session 参数
# - False: 不同步会话最快
# - 'fetch': 更新前获取对象
# - 'evaluate': 评估表达式

单条删除

1
2
3
4
5
6
7
8
# 查询后删除
user = session.query(User).get(1)
session.delete(user)
session.commit()

# 直接删除
session.query(User).filter(User.id == 1).delete()
session.commit()

批量删除

1
2
3
4
# 删除所有符合条件的记录
deleted_count = session.query(User).filter(User.age < 18).delete()
session.commit()
print(f"删除了 {deleted_count} 条记录")

高级查询

一对一关系

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship

# 定义模型
class User(Base):
__tablename__ = 'user'

id = Column(Integer, primary_key=True)
username = Column(String(50))

# 一对一关系
profile = relationship("Profile", back_populates="user", uselist=False)

class Profile(Base):
__tablename__ = 'profile'

id = Column(Integer, primary_key=True)
bio = Column(String(200))
user_id = Column(Integer, ForeignKey('user.id'), unique=True)

# 反向关系
user = relationship("User", back_populates="profile")

# 查询
user = session.query(User).options(joinedload(User.profile)).filter_by(username='张三').first()
print(user.profile.bio)

一对多关系

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class User(Base):
__tablename__ = 'user'

id = Column(Integer, primary_key=True)
username = Column(String(50))

# 一对多关系
orders = relationship("Order", back_populates="user")

class Order(Base):
__tablename__ = 'order'

id = Column(Integer, primary_key=True)
order_no = Column(String(50))
user_id = Column(Integer, ForeignKey('user.id'))

# 反向关系
user = relationship("User", back_populates="orders")

# 查询用户及其订单
user = session.query(User).options(joinedload(User.orders)).get(1)
for order in user.orders:
print(order.order_no)

# 查询有订单的用户
users = session.query(User).join(Order).all()

多对多关系

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 关联表
student_course = Table('student_course', Base.metadata,
Column('student_id', Integer, ForeignKey('student.id')),
Column('course_id', Integer, ForeignKey('course.id'))
)

class Student(Base):
__tablename__ = 'student'

id = Column(Integer, primary_key=True)
name = Column(String(50))

# 多对多关系
courses = relationship("Course", secondary=student_course, back_populates="students")

class Course(Base):
__tablename__ = 'course'

id = Column(Integer, primary_key=True)
name = Column(String(50))

# 反向关系
students = relationship("Student", secondary=student_course, back_populates="courses")

# 查询学生及其课程
student = session.query(Student).options(joinedload(Student.courses)).get(1)
for course in student.courses:
print(course.name)

子查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from sqlalchemy import func

# 子查询作为条件
subquery = session.query(User.id).filter(User.age > 25).subquery()
users = session.query(User).filter(User.id.in_(subquery)).all()

# 子查询作为列
subquery = session.query(
Order.user_id,
func.count(Order.id).label('order_count')
).group_by(Order.user_id).subquery()

users = session.query(User, subquery.c.order_count).outerjoin(
subquery, User.id == subquery.c.user_id
).all()

原生 SQL

1
2
3
4
5
6
7
8
9
10
from sqlalchemy import text

# 执行原生 SQL
result = session.execute(text("SELECT * FROM user WHERE age > :age"), {"age": 25})
users = result.fetchall()

# 原生 SQL 插入
session.execute(text("INSERT INTO user (username, password) VALUES (:username, :password)"),
{"username": "测试", "password": "123456"})
session.commit()

会话管理

生命周期

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from contextlib import contextmanager

# 方式一手动管理
session = SessionLocal()
try:
# 执行操作
user = session.query(User).get(1)
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()

# 方式二上下文管理器
@contextmanager
def get_session():
session = SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()

# 使用
with get_session() as session:
user = session.query(User).get(1)

事务管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 自动提交
session = SessionLocal(autocommit=False, autoflush=True)

try:
user1 = User(username='用户1', password='123456')
user2 = User(username='用户2', password='123456')

session.add(user1)
session.add(user2)

# 显式提交
session.commit()

except Exception as e:
# 发生错误时回滚
session.rollback()
print(f"事务回滚: {e}")

连接池配置

1
2
3
4
5
6
7
8
9
10
from sqlalchemy import create_engine

engine = create_engine(
'mysql+pymysql://root:123456@localhost:3306/db',
pool_size=10, # 连接池大小
max_overflow=20, # 超出 pool_size 后最多创建的连接数
pool_timeout=30, # 获取连接的超时时间
pool_recycle=3600, # 连接回收时间
pool_pre_ping=True # 使用前检查连接是否有效
)

Flask 整合

安装

1
pip install flask-sqlalchemy

配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)

# 配置数据库
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:123456@localhost:3306/flask_demo'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['SQLALCHEMY_ECHO'] = True

# 初始化
db = SQLAlchemy(app)

# 定义模型
class User(db.Model):
__tablename__ = 'user'

id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(50), nullable=False)
email = db.Column(db.String(100))
age = db.Column(db.Integer)

def to_dict(self):
return {
'id': self.id,
'username': self.username,
'email': self.email,
'age': self.age
}

# 创建表
with app.app_context():
db.create_all()

路由

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from flask import jsonify, request

@app.route('/users', methods=['GET'])
def get_users():
users = User.query.all()
return jsonify([user.to_dict() for user in users])

@app.route('/users/<int:id>', methods=['GET'])
def get_user(id):
user = User.query.get_or_404(id)
return jsonify(user.to_dict())

@app.route('/users', methods=['POST'])
def create_user():
data = request.json
user = User(
username=data['username'],
email=data.get('email'),
age=data.get('age')
)
db.session.add(user)
db.session.commit()
return jsonify(user.to_dict()), 201

@app.route('/users/<int:id>', methods=['PUT'])
def update_user(id):
user = User.query.get_or_404(id)
data = request.json

user.username = data.get('username', user.username)
user.email = data.get('email', user.email)
user.age = data.get('age', user.age)

db.session.commit()
return jsonify(user.to_dict())

@app.route('/users/<int:id>', methods=['DELETE'])
def delete_user(id):
user = User.query.get_or_404(id)
db.session.delete(user)
db.session.commit()
return '', 204

整合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
from pydantic import BaseModel

app = FastAPI()

# 依赖注入
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

# Pydantic 模型
class UserCreate(BaseModel):
username: str
email: str = None
age: int = None

class UserResponse(BaseModel):
id: int
username: str
email: str = None
age: int = None

class Config:
from_attributes = True

# API 路由
@app.get("/users", response_model=list[UserResponse])
def get_users(db: Session = Depends(get_db)):
return db.query(User).all()

@app.get("/users/{user_id}", response_model=UserResponse)
def get_user(user_id: int, db: Session = Depends(get_db)):
user = db.query(User).get(user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return user

@app.post("/users", response_model=UserResponse)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
db_user = User(**user.dict())
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user

性能优化

命名规范

1
2
3
4
5
6
7
8
9
10
模型类命名
- 类名大驼峰PascalCase如 UserUserProfile
- 表名小写复数如 usersuser_profiles
- 字段名小写下划线如 user_namecreate_time

文件组织
- models.py: 模型定义
- database.py: 数据库配置
- schemas.py: Pydantic 模式
- crud.py: CRUD 操作

代码规范

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 1. 使用上下文管理器
with SessionLocal() as session:
user = session.query(User).get(1)

# 2. 异常处理
try:
session.add(user)
session.commit()
except Exception as e:
session.rollback()
logger.error(f"数据库操作失败: {e}")
raise

# 3. 类型提示
from typing import List, Optional

def get_users(session: Session, limit: int = 10) -> List[User]:
return session.query(User).limit(limit).all()

# 4. 文档字符串
def create_user(session: Session, username: str, email: str) -> User:
"""
创建新用户

Args:
session: 数据库会话
username: 用户名
email: 邮箱

Returns:
创建的用户对象

Raises:
IntegrityError: 用户名已存在
"""
user = User(username=username, email=email)
session.add(user)
session.commit()
return user

性能优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 1. 避免 N+1 查询问题
# 错误示例
users = session.query(User).all()
for user in users:
print(user.profile.bio) # 每次循环都执行一次查询

# 正确示例使用 joinedload
from sqlalchemy.orm import joinedload
users = session.query(User).options(joinedload(User.profile)).all()
for user in users:
print(user.profile.bio) # 只执行一次查询

# 2. 只查询需要的字段
users = session.query(User.id, User.username).all()

# 3. 使用延迟加载
from sqlalchemy.orm import deferred
class User(Base):
large_data = deferred(Column(Text)) # 只在访问时加载

# 4. 批量操作
session.bulk_insert_mappings(User, user_list)

# 5. 合理使用索引
class User(Base):
__table_args__ = (
Index('idx_username', 'username'),
Index('idx_email', 'email'),
)

安全建议

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 1. 防止 SQL 注入使用参数化查询
# 正确
user = session.query(User).filter(User.username == username).first()

# 错误不要拼接 SQL
# session.execute(f"SELECT * FROM user WHERE username = '{username}'")

# 2. 密码加密
from werkzeug.security import generate_password_hash, check_password_hash

class User(Base):
password_hash = Column(String(128))

def set_password(self, password):
self.password_hash = generate_password_hash(password)

def check_password(self, password):
return check_password_hash(self.password_hash, password)

# 3. 输入验证
from pydantic import BaseModel, validator

class UserCreate(BaseModel):
username: str
email: str

@validator('username')
def validate_username(cls, v):
if len(v) < 3 or len(v) > 50:
raise ValueError('用户名长度必须在 3-50 之间')
return v

常见问题

中文乱码问题

1
2
3
4
5
6
7
8
# 确保数据库 URL 包含编码参数
engine = create_engine(
'mysql+pymysql://root:123456@localhost:3306/db?charset=utf8mb4',
encoding='utf-8'
)

# 创建数据库时指定字符集
# CREATE DATABASE dbname CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;

日期时间处理

1
2
3
4
5
6
7
8
9
10
from datetime import datetime
from sqlalchemy import Column, DateTime

class User(Base):
create_time = Column(DateTime, default=datetime.now)
update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now)

# 使用时区感知的时间
from datetime import timezone
create_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

循环导入问题

1
2
3
4
5
6
7
8
9
10
11
12
# 问题两个模型互相引用导致循环导入
# 解决方案一使用字符串引用
class User(Base):
orders = relationship("Order", back_populates="user")

class Order(Base):
user = relationship("User", back_populates="orders")

# 解决方案二延迟导入
def get_relationships():
from models import Order
return relationship(Order)

报错处理

💗💗 SQLAlchemy 报错InvalidRequestError

1
2
3
4
5
6
7
8
9
10
错误信息
Object '<User at 0x...>' is already attached to session '...'

错误原因
对象已经被添加到另一个 Session

解决方案
1. 确保每个对象只属于一个 Session
2. 使用 session.merge() 合并对象
3. 从原 Session 中 expunge 对象后再添加到新 Session

💗💗 SQLAlchemy 报错DetachedInstanceError

1
2
3
4
5
6
7
8
9
10
错误信息
Instance '<User at 0x...>' is not bound to a Session

错误原因
Session 已关闭但仍在访问对象的懒加载属性

解决方案
1. 在 Session 关闭前访问所有需要的属性
2. 使用 eager loadingjoinedload预加载关系
3. 延长 Session 的生命周期

💗💗 SQLAlchemy 报错IntegrityError

1
2
3
4
5
6
7
8
9
10
11
错误信息
(sqlite3.IntegrityError) UNIQUE constraint failed: user.username

错误原因
违反唯一约束或外键约束

解决方案
1. 检查是否有重复数据
2. 使用 try-except 捕获异常
3. 在插入前检查数据是否存在
4. 使用 upsert 模式先查询再决定插入或更新

💗💗 SQLAlchemy 报错OperationalError

1
2
3
4
5
6
7
8
9
10
11
错误信息
Can't connect to MySQL server on 'localhost'

错误原因
数据库连接失败

解决方案
1. 检查数据库服务是否启动
2. 验证连接参数主机端口用户名密码
3. 检查防火墙设置
4. 确认数据库是否存在

Matplotlib

Matplotlib 是 Python 中最流行的数据可视化库提供了丰富的绘图功能和高度可定制的图表样式它支持多种图表类型能够生成高质量的静态动态和交互式可视化作品

  • 核心功能
    • 多种图表线图柱状图散点图饼图直方图等
    • 高度定制颜色样式标签图例等全面控制
    • 多格式输出PNGPDFSVGEPS 等多种格式
    • 子图布局支持复杂的图表布局和组合
    • 3D 绘图通过 mplot3d 工具包支持三维可视化
  • 主要优势
    • 成熟稳定历史悠久社区活跃文档完善
    • 灵活强大几乎可以绘制任何类型的图表
    • 生态集成与 NumPyPandasSciPy 无缝配合
    • 跨平台支持 WindowsLinuxmacOS

核心组件

  • Figure画布
    • 整个图表的容器
    • 可以包含多个 Axes子图
  • Axes坐标系/子图
    • 实际的绘图区域
    • 包含坐标轴标题图例等元素
  • Axis坐标轴
    • X 轴和 Y 轴
    • 控制刻度标签范围等
  • Artist艺术家对象
    • 所有可见元素的基类
    • 包括线条文本图例等

工作原理

绘图流程

1
2
3
4
5
1. 创建 Figure 对象 → 设置画布大小和分辨率
2. 添加 Axes 对象 → 创建子图
3. 绘制数据 → 调用 plotbarscatter 等方法
4. 定制样式 → 设置标题标签图例颜色等
5. 显示或保存 → plt.show() 或 fig.savefig()

渲染机制

1
2
3
4
5
6
7
数据准备 → Artist 对象创建 → 渲染引擎处理 → 输出图像

关键步骤
- 数据转换将数据转换为 Artist 对象
- 布局计算计算各元素位置和大小
- 渲染使用后端渲染引擎绘制
- 输出保存为文件或显示在屏幕上

环境搭建

pip 安装

1
2
3
4
5
6
7
8
9
10
11
# 基础安装
pip install matplotlib

# 指定版本安装
pip install matplotlib==3.8.2

# 升级 Matplotlib
pip install --upgrade matplotlib

# 安装完整功能包含所有可选依赖
pip install matplotlib[all]

conda 安装

1
2
3
4
5
6
7
8
# 使用 conda 安装
conda install matplotlib

# 指定版本
conda install matplotlib=3.8.2

# 在指定环境中安装
conda install -n myenv matplotlib

验证安装

1
2
3
4
5
6
7
8
9
10
11
12
import matplotlib
print(matplotlib.__version__)
# 输出: 3.8.2

# 查看后端信息
print(matplotlib.get_backend())

# 测试绘图
import matplotlib.pyplot as plt
plt.plot([1, 2, 3], [1, 4, 9])
plt.title('Test Plot')
plt.show()

依赖环境

1
2
3
4
5
6
7
8
9
10
11
12
13
matplotlib 主要依赖
- numpy: 数值计算基础
- pillow: 图像处理
- fonttools: 字体处理
- kiwisolver: 布局求解器
- cycler: 属性循环
- packaging: 版本管理

可选依赖
- pandas: 数据框支持
- scipy: 科学计算
- PyQt5/PySide2: GUI 后端
- LaTeX: 数学公式渲染

💗💗 requirements.txt

1
2
3
4
matplotlib==3.8.2
numpy==1.26.2
pandas==2.1.4
pillow==10.1.0

后端选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib
import matplotlib.pyplot as plt

# 查看可用后端
print(matplotlib.rcsetup.all_backends)

# 设置后端必须在导入 pyplot 之前
matplotlib.use('Agg') # 非交互式服务器
matplotlib.use('TkAgg') # Tkinter 交互式
matplotlib.use('Qt5Agg') # Qt5 交互式

# Jupyter Notebook 中
%matplotlib inline # 静态显示
%matplotlib notebook # 交互式
%matplotlib widget # 新式交互需要 ipympl

样式配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 查看所有可用样式
print(plt.style.available)

# 使用内置样式
plt.style.use('ggplot')
plt.style.use('seaborn-v0_8')
plt.style.use('dark_background')
plt.style.use('fivethirtyeight')

# 自定义样式
plt.style.use({
'axes.facecolor': '#EAEAF2',
'axes.edgecolor': 'white',
'axes.grid': True,
'grid.color': 'white',
'grid.linestyle': '-',
'font.size': 12,
'lines.linewidth': 2
})

# 恢复默认样式
plt.style.use('default')

中文显示配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import matplotlib.pyplot as plt

# Windows 系统
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False

# macOS 系统
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'Heiti TC']
plt.rcParams['axes.unicode_minus'] = False

# Linux 系统
plt.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# 永久配置修改 matplotlibrc 文件
# 找到配置文件位置
print(matplotlib.matplotlib_fname())

基础绘图

基本线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import matplotlib.pyplot as plt
import numpy as np

# 简单线图
x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.plot(x, y)
plt.title('Simple Line Plot')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.show()

# 多条线
y2 = np.cos(x)
plt.plot(x, y, label='sin(x)')
plt.plot(x, y2, label='cos(x)')
plt.legend()
plt.title('Multiple Lines')
plt.show()

样式定制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 线条样式
plt.plot(x, y, color='red', linewidth=2, linestyle='--')
plt.plot(x, y, color='#FF5733', linewidth=3, linestyle='-.')

# 标记点
plt.plot(x, y, marker='o', markersize=5, markerfacecolor='red')
plt.plot(x, y, marker='s', markeredgecolor='blue', markeredgewidth=2)

# 常用标记符号
# 'o' 圆形, 's' 方形, '^' 三角形, 'D' 菱形
# '*' 星形, '+' 加号, 'x' 叉号

# 完整示例
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'r-o', label='sin(x)', linewidth=2, markersize=6)
plt.plot(x, y2, 'b--s', label='cos(x)', linewidth=2, markersize=6)
plt.xlabel('X', fontsize=12)
plt.ylabel('Y', fontsize=12)
plt.title('Styled Line Plot', fontsize=14, fontweight='bold')
plt.legend(loc='best', fontsize=11)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

基本散点图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 简单散点图
np.random.seed(42)
x = np.random.rand(50)
y = np.random.rand(50)

plt.scatter(x, y)
plt.title('Simple Scatter Plot')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

# 带颜色和大小
sizes = np.random.rand(50) * 1000
colors = np.random.rand(50)

plt.scatter(x, y, s=sizes, c=colors, alpha=0.6, cmap='viridis')
plt.colorbar(label='Color Value')
plt.title('Colored Scatter Plot')
plt.show()

高级散点图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 分组散点图
group1_x = np.random.randn(50)
group1_y = np.random.randn(50)
group2_x = np.random.randn(50) + 2
group2_y = np.random.randn(50) + 2

plt.scatter(group1_x, group1_y, c='red', label='Group 1', alpha=0.6)
plt.scatter(group2_x, group2_y, c='blue', label='Group 2', alpha=0.6)
plt.legend()
plt.title('Grouped Scatter Plot')
plt.grid(True, alpha=0.3)
plt.show()

# 气泡图
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(x, y, s=sizes*100, c=colors,
alpha=0.6, cmap='plasma', edgecolors='black')
plt.colorbar(scatter, label='Intensity')
ax.set_xlabel('X Axis', fontsize=12)
ax.set_ylabel('Y Axis', fontsize=12)
ax.set_title('Bubble Chart', fontsize=14)
plt.show()

垂直柱状图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 简单柱状图
categories = ['A', 'B', 'C', 'D', 'E']
values = [25, 40, 30, 55, 45]

plt.bar(categories, values, color='steelblue', edgecolor='black')
plt.title('Vertical Bar Chart')
plt.xlabel('Categories')
plt.ylabel('Values')
plt.show()

# 水平柱状图
plt.barh(categories, values, color='coral', edgecolor='black')
plt.title('Horizontal Bar Chart')
plt.xlabel('Values')
plt.ylabel('Categories')
plt.show()

分组和堆叠柱状图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# 分组柱状图
categories = ['Q1', 'Q2', 'Q3', 'Q4']
product_a = [30, 45, 50, 60]
product_b = [25, 40, 55, 65]

x = np.arange(len(categories))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - width/2, product_a, width, label='Product A', color='#4CAF50')
bars2 = ax.bar(x + width/2, product_b, width, label='Product B', color='#2196F3')

ax.set_xlabel('Quarter')
ax.set_ylabel('Sales')
ax.set_title('Grouped Bar Chart')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()

# 添加数值标签
for bar in bars1:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height,
f'{height}', ha='center', va='bottom')
for bar in bars2:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height,
f'{height}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# 堆叠柱状图
fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(categories, product_a, label='Product A', color='#4CAF50')
bars2 = ax.bar(categories, product_b, bottom=product_a,
label='Product B', color='#2196F3')

ax.set_ylabel('Sales')
ax.set_title('Stacked Bar Chart')
ax.legend()
plt.show()

基本饼图

1
2
3
4
5
6
7
8
9
10
11
# 简单饼图
labels = ['Python', 'Java', 'C++', 'JavaScript', 'Other']
sizes = [35, 25, 15, 20, 5]
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8']

plt.figure(figsize=(10, 8))
plt.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%',
startangle=90, shadow=True, explode=(0.05, 0.05, 0.05, 0.05, 0))
plt.axis('equal') # 保证饼图为圆形
plt.title('Programming Language Usage', fontsize=14)
plt.show()

环形图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 环形图Donut Chart
fig, ax = plt.subplots(figsize=(10, 8))
wedges, texts, autotexts = ax.pie(sizes, labels=labels, colors=colors,
autopct='%1.1f%%', startangle=90,
pctdistance=0.85)

# 绘制中心圆形成环形
centre_circle = plt.Circle((0, 0), 0.70, fc='white')
fig.gca().add_artist(centre_circle)

# 设置字体
for text in texts:
text.set_fontsize(11)
for autotext in autotexts:
autotext.set_fontsize(10)
autotext.set_color('white')

ax.set_title('Donut Chart', fontsize=14)
plt.axis('equal')
plt.tight_layout()
plt.show()

基本直方图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 简单直方图
data = np.random.randn(1000)

plt.hist(data, bins=30, color='steelblue', edgecolor='black', alpha=0.7)
plt.title('Histogram')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)
plt.show()

# 多个直方图
data1 = np.random.normal(0, 1, 1000)
data2 = np.random.normal(2, 1.5, 1000)

plt.hist(data1, bins=30, alpha=0.5, label='Dataset 1', color='blue')
plt.hist(data2, bins=30, alpha=0.5, label='Dataset 2', color='red')
plt.legend()
plt.title('Multiple Histograms')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.show()

累积直方图和密度图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 累积直方图
plt.hist(data, bins=30, cumulative=True, density=True,
histtype='step', linewidth=2, color='purple')
plt.title('Cumulative Histogram')
plt.xlabel('Value')
plt.ylabel('Cumulative Probability')
plt.grid(True, alpha=0.3)
plt.show()

# 直方图 + 密度曲线
from scipy import stats

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(data, bins=30, density=True, alpha=0.6, color='steelblue',
edgecolor='black', label='Histogram')

# 添加密度曲线
kde = stats.gaussian_kde(data)
x_range = np.linspace(data.min(), data.max(), 100)
ax.plot(x_range, kde(x_range), 'r-', linewidth=2, label='KDE')

ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Histogram with KDE')
ax.legend()
plt.grid(True, alpha=0.3)
plt.show()

高级图表

箱线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 基本箱线图
data = [np.random.normal(0, std, 100) for std in range(1, 4)]

fig, ax = plt.subplots(figsize=(10, 6))
bp = ax.boxplot(data, labels=['Group 1', 'Group 2', 'Group 3'],
patch_artist=True,
boxprops=dict(facecolor='lightblue', color='blue'),
medianprops=dict(color='red', linewidth=2),
whiskerprops=dict(color='green'),
capprops=dict(color='green'),
flierprops=dict(marker='o', markerfacecolor='red', markersize=6))

ax.set_ylabel('Value')
ax.set_title('Box Plot')
ax.grid(True, alpha=0.3, axis='y')
plt.show()

# 水平箱线图
ax.boxplot(data, vert=False, labels=['Group 1', 'Group 2', 'Group 3'])
ax.set_xlabel('Value')
ax.set_title('Horizontal Box Plot')
plt.show()

热力图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 基本热力图
data = np.random.rand(10, 10)

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(data, cmap='YlOrRd', aspect='auto')

# 添加颜色条
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Intensity', rotation=270, labelpad=15)

# 添加网格和标签
ax.set_xticks(np.arange(10))
ax.set_yticks(np.arange(10))
ax.set_xticklabels([f'Col {i}' for i in range(10)])
ax.set_yticklabels([f'Row {i}' for i in range(10)])

# 在每个单元格中添加数值
for i in range(10):
for j in range(10):
text = ax.text(j, i, f'{data[i, j]:.2f}',
ha="center", va="center", color="black", fontsize=8)

ax.set_title('Heatmap')
plt.tight_layout()
plt.show()

# 使用 seaborn 风格如果安装了 seaborn
try:
import seaborn as sns
corr_matrix = np.random.rand(10, 10)
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='coolwarm',
square=True, linewidths=0.5)
plt.title('Correlation Heatmap')
plt.show()
except ImportError:
print("Seaborn not installed")

面积图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 堆叠面积图
x = np.arange(0, 10, 0.5)
y1 = np.sin(x) + 2
y2 = np.cos(x) + 2
y3 = np.tan(x) * 0.3 + 2

fig, ax = plt.subplots(figsize=(10, 6))
ax.fill_between(x, 0, y1, alpha=0.4, label='Series 1', color='#FF6B6B')
ax.fill_between(x, y1, y1+y2, alpha=0.4, label='Series 2', color='#4ECDC4')
ax.fill_between(x, y1+y2, y1+y2+y3, alpha=0.4, label='Series 3', color='#45B7D1')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Stacked Area Chart')
ax.legend(loc='upper left')
plt.grid(True, alpha=0.3)
plt.show()

雷达图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 雷达图
categories = ['Speed', 'Reliability', 'Comfort', 'Safety', 'Efficiency']
values = [4, 3, 5, 4, 3]

# 计算角度
angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
values += values[:1] # 闭合图形
angles += angles[:1]

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
ax.plot(angles, values, 'o-', linewidth=2, color='steelblue')
ax.fill(angles, values, alpha=0.25, color='steelblue')

# 设置标签
ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories, fontsize=11)
ax.set_ylim(0, 5)
ax.set_title('Radar Chart', fontsize=14, pad=20)
ax.grid(True)

plt.tight_layout()
plt.show()

3D 绘图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from mpl_toolkits.mplot3d import Axes3D

# 3D 散点图
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

x = np.random.rand(100)
y = np.random.rand(100)
z = np.random.rand(100)

ax.scatter(x, y, z, c=z, cmap='viridis', s=50)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.set_title('3D Scatter Plot')
plt.show()

# 3D 曲面图
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

surf = ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('3D Surface Plot')
fig.colorbar(surf, shrink=0.5, aspect=10)
plt.show()

子图与布局

subplot 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 2x2 子图
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

x = np.linspace(0, 10, 100)

# 第一个子图
axes[0, 0].plot(x, np.sin(x), 'r-')
axes[0, 0].set_title('Sine Wave')
axes[0, 0].grid(True, alpha=0.3)

# 第二个子图
axes[0, 1].plot(x, np.cos(x), 'b-')
axes[0, 1].set_title('Cosine Wave')
axes[0, 1].grid(True, alpha=0.3)

# 第三个子图
axes[1, 0].plot(x, np.tan(x), 'g-')
axes[1, 0].set_title('Tangent Wave')
axes[1, 0].set_ylim(-5, 5)
axes[1, 0].grid(True, alpha=0.3)

# 第四个子图
axes[1, 1].plot(x, x**2, 'm-')
axes[1, 1].set_title('Quadratic')
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle('Multiple Subplots', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

add_subplot 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 不规则子图
fig = plt.figure(figsize=(12, 8))

ax1 = fig.add_subplot(2, 2, 1)
ax1.plot(x, np.sin(x), 'r-')
ax1.set_title('Plot 1')

ax2 = fig.add_subplot(2, 2, 2)
ax2.plot(x, np.cos(x), 'b-')
ax2.set_title('Plot 2')

ax3 = fig.add_subplot(2, 1, 2) # 占据下半部分
ax3.plot(x, np.sin(x) * np.cos(x), 'g-')
ax3.set_title('Plot 3 (Wide)')

plt.tight_layout()
plt.show()

GridSpec 布局

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import matplotlib.gridspec as gridspec

# 使用 GridSpec
fig = plt.figure(figsize=(12, 10))
gs = gridspec.GridSpec(3, 3, hspace=0.3, wspace=0.3)

# 不同大小的子图
ax1 = fig.add_subplot(gs[0, :]) # 第一行全部
ax1.plot(x, np.sin(x), 'r-')
ax1.set_title('Wide Plot')

ax2 = fig.add_subplot(gs[1, :-1]) # 第二行前两列
ax2.plot(x, np.cos(x), 'b-')
ax2.set_title('Medium Plot')

ax3 = fig.add_subplot(gs[1:, -1]) # 右侧两行
ax3.plot(x, np.tan(x), 'g-')
ax3.set_title('Tall Plot')
ax3.set_ylim(-5, 5)

ax4 = fig.add_subplot(gs[2, 0]) # 左下角
ax4.plot(x, x**2, 'm-')
ax4.set_title('Small Plot 1')

ax5 = fig.add_subplot(gs[2, 1]) # 下中
ax5.plot(x, np.exp(-x/10), 'c-')
ax5.set_title('Small Plot 2')

plt.suptitle('GridSpec Layout', fontsize=16)
plt.show()

inset_axes 插图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 添加插图
fig, ax = plt.subplots(figsize=(10, 8))

# 主图
x = np.linspace(0, 10, 1000)
ax.plot(x, np.sin(x), 'b-', linewidth=2)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Main Plot with Inset')
ax.grid(True, alpha=0.3)

# 插图放大局部
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

axins = inset_axes(ax, width="40%", height="40%", loc='upper right')
x_zoom = np.linspace(2, 4, 100)
axins.plot(x_zoom, np.sin(x_zoom), 'r-', linewidth=2)
axins.set_xlim(2, 4)
axins.set_ylim(-1, 1)
axins.set_xticks([])
axins.set_yticks([])

# 添加指示框
ax.indicate_inset_zoom(axins, edgecolor="red")

plt.tight_layout()
plt.show()

样式定制

颜色设置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 命名颜色
plt.plot(x, np.sin(x), color='red')
plt.plot(x, np.cos(x), color='blue')

# RGB/RGBA
plt.plot(x, np.sin(x), color=(1, 0, 0)) # RGB
plt.plot(x, np.cos(x), color=(1, 0, 0, 0.5)) # RGBA

# HEX 颜色
plt.plot(x, np.sin(x), color='#FF5733')

# 颜色循环
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8']
for i, color in enumerate(colors):
plt.plot(x, np.sin(x + i), color=color, label=f'Line {i+1}')
plt.legend()
plt.show()

Colormap 使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 使用 colormap
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

cmaps = ['viridis', 'plasma', 'inferno', 'magma', 'cividis', 'jet']

for ax, cmap in zip(axes.flat, cmaps):
data = np.random.rand(10, 10)
im = ax.imshow(data, cmap=cmap)
ax.set_title(cmap)
plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.suptitle('Colormaps', fontsize=16)
plt.tight_layout()
plt.show()

# 连续型和离散型 colormap
from matplotlib.colors import ListedColormap, BoundaryNorm

# 自定义离散 colormap
colors = ['red', 'orange', 'yellow', 'green', 'blue']
cmap = ListedColormap(colors)
bounds = [0, 1, 2, 3, 4, 5]
norm = BoundaryNorm(bounds, cmap.N)

data = np.random.randint(0, 5, (10, 10))
plt.imshow(data, cmap=cmap, norm=norm)
plt.colorbar(ticks=[0.5, 1.5, 2.5, 3.5, 4.5])
plt.title('Discrete Colormap')
plt.show()

字体设置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 全局字体设置
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.size'] = 12

# 局部字体设置
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_title('Title', fontsize=16, fontweight='bold', family='serif')
ax.set_xlabel('X Label', fontsize=12, style='italic')
ax.set_ylabel('Y Label', fontsize=12)

# 添加文本注释
ax.text(5, 0.5, 'Annotation', fontsize=10,
bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.5))
plt.show()

# 数学公式
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_title(r'$\alpha > \beta$ and $\int_{0}^{\infty} e^{-x} dx$', fontsize=14)
ax.text(2, 0.5, r'$E = mc^2$', fontsize=16,
bbox=dict(boxstyle='round', facecolor='wheat'))
plt.show()

文本标注

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 箭头标注
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

ax.annotate('Local Maximum', xy=(np.pi/2, 1), xytext=(3, 1.5),
arrowprops=dict(facecolor='red', shrink=0.05, width=2, headwidth=8),
fontsize=12, color='red')

ax.annotate('Zero Crossing', xy=(np.pi, 0), xytext=(4, -0.5),
arrowprops=dict(arrowstyle='->', color='blue', lw=2),
fontsize=10, color='blue')

ax.set_title('Annotations')
plt.grid(True, alpha=0.3)
plt.show()

图例定制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# 图例位置
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')

# 不同位置
# 'best', 'upper right', 'upper left', 'lower left', 'lower right'
# 'right', 'center left', 'center right', 'lower center', 'upper center', 'center'
ax.legend(loc='upper right', fontsize=11, framealpha=0.8,
title='Legend Title', title_fontsize=12)

plt.title('Legend Position')
plt.show()

# 图外图例
fig, ax = plt.subplots(figsize=(10, 6))
for i in range(5):
ax.plot(x, np.sin(x + i), label=f'Line {i+1}')

ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10)
plt.title('Legend Outside Plot')
plt.tight_layout()
plt.show()

# 多列图例
fig, ax = plt.subplots()
for i in range(8):
ax.plot(x, np.sin(x + i), label=f'Line {i+1}')

ax.legend(ncol=4, loc='upper center', bbox_to_anchor=(0.5, -0.1),
fontsize=9, frameon=False)
plt.title('Multi-column Legend')
plt.tight_layout()
plt.show()

保存图表

基本保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 保存为 PNG
plt.savefig('plot.png', dpi=300, bbox_inches='tight')

# 保存为 PDF矢量图
plt.savefig('plot.pdf', bbox_inches='tight')

# 保存为 SVG
plt.savefig('plot.svg', bbox_inches='tight')

# 保存为 EPS
plt.savefig('plot.eps', bbox_inches='tight')

# 高质量保存
plt.savefig('plot_high_quality.png',
dpi=600,
bbox_inches='tight',
pad_inches=0.5,
facecolor='white',
edgecolor='none',
transparent=False)

批量保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 批量生成并保存图表
datasets = [np.random.randn(100) for _ in range(5)]

for i, data in enumerate(datasets):
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(data, bins=20, color='steelblue', edgecolor='black')
ax.set_title(f'Dataset {i+1}')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')

# 保存
plt.savefig(f'histogram_{i+1}.png', dpi=300, bbox_inches='tight')
plt.close() # 关闭图形释放内存

print("All plots saved!")

格式对比

格式 类型 优点 缺点 适用场景
PNG 位图 兼容性好文件小 放大失真 网页展示演示文稿
JPG 位图 压缩率高 有损压缩质量损失 照片复杂图像
PDF 矢量图 无损缩放专业印刷 文件较大 学术论文出版物
SVG 矢量图 可编辑Web 友好 复杂图表文件大 网页嵌入图标
EPS 矢量图 LaTeX 兼容 老旧格式 LaTeX 文档

与 Pandas 集成

DataFrame 直接绘图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import pandas as pd

# 创建示例数据
df = pd.DataFrame({
'date': pd.date_range('2024-01-01', periods=100),
'value1': np.random.randn(100).cumsum(),
'value2': np.random.randn(100).cumsum() + 10,
'category': np.random.choice(['A', 'B', 'C'], 100)
})
df = df.set_index('date')

# DataFrame 直接绘图
df[['value1', 'value2']].plot(figsize=(12, 6), title='Time Series')
plt.xlabel('Date')
plt.ylabel('Value')
plt.grid(True, alpha=0.3)
plt.show()

# 不同类型图表
df['value1'].plot(kind='line', label='Value 1')
df['value1'].plot(kind='area', alpha=0.3, label='Value 1 Area')
plt.legend()
plt.title('Line and Area')
plt.show()

# 柱状图
df.groupby('category')['value1'].mean().plot(kind='bar', color='steelblue')
plt.title('Category Average')
plt.ylabel('Mean Value')
plt.xticks(rotation=0)
plt.show()

# 箱线图
df.boxplot(column='value1', by='category', figsize=(10, 6))
plt.title('Box Plot by Category')
plt.suptitle('') # 移除自动生成的标题
plt.show()

时间序列可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 时间序列特殊处理
fig, axes = plt.subplots(2, 1, figsize=(12, 10))

# 原始数据
axes[0].plot(df.index, df['value1'], label='Value 1')
axes[0].plot(df.index, df['value2'], label='Value 2')
axes[0].set_title('Original Time Series')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# 滚动平均
df['rolling_mean'] = df['value1'].rolling(window=7).mean()
axes[1].plot(df.index, df['value1'], alpha=0.3, label='Original')
axes[1].plot(df.index, df['rolling_mean'], 'r-', linewidth=2, label='7-day MA')
axes[1].set_title('With Moving Average')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

动画与交互

基本动画

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from matplotlib.animation import FuncAnimation

# 简单动画
fig, ax = plt.subplots()
xdata, ydata = [], []
ln, = ax.plot([], [], 'r-', animated=True)
ax.set_xlim(0, 10)
ax.set_ylim(-1, 1)
ax.set_title('Animation Example')

def init():
return ln,

def update(frame):
xdata.append(frame)
ydata.append(np.sin(frame))
ln.set_data(xdata, ydata)
return ln,

ani = FuncAnimation(fig, update, frames=np.linspace(0, 10, 100),
init_func=init, blit=True, interval=50)

# 保存动画
# ani.save('animation.gif', writer='pillow', fps=20)
# ani.save('animation.mp4', writer='ffmpeg', fps=20)

plt.show()

交互式图表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 使用 widgets需要 ipywidgets
try:
from ipywidgets import interact

def plot_function(amplitude=1.0, frequency=1.0, phase=0.0):
x = np.linspace(0, 10, 1000)
y = amplitude * np.sin(frequency * x + phase)

plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)
plt.ylim(-2, 2)
plt.title(f'Sine Wave: A={amplitude}, f={frequency}, φ={phase}')
plt.grid(True, alpha=0.3)
plt.show()

interact(plot_function,
amplitude=(0.5, 2.0, 0.1),
frequency=(0.5, 5.0, 0.1),
phase=(0, 2*np.pi, 0.1))
except ImportError:
print("ipywidgets not installed")

性能优化

代码规范

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 1. 使用面向对象接口推荐
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y)
ax.set_title('Title')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.tight_layout()
plt.show()

# 而非 pyplot 状态机接口
plt.figure(figsize=(10, 6))
plt.plot(x, y)
plt.title('Title')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

# 2. 及时关闭图形
for i in range(10):
fig, ax = plt.subplots()
ax.plot(x, y)
plt.savefig(f'plot_{i}.png')
plt.close(fig) # 释放内存

# 3. 使用上下文管理器
with plt.style.context('seaborn-v0_8'):
fig, ax = plt.subplots()
ax.plot(x, y)
plt.show()

# 4. 设置随机种子保证可复现
np.random.seed(42)
data = np.random.randn(100)

性能优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 1. 减少点数对于大数据集
x_large = np.linspace(0, 10, 100000)
# 降采样
x_sampled = x_large[::100]
y_sampled = np.sin(x_sampled)

# 2. 使用 rasterized 参数
fig, ax = plt.subplots()
ax.scatter(x, y, s=1, rasterized=True) # 散点图栅格化
plt.savefig('large_scatter.pdf') # PDF 文件更小

# 3. 批量操作
# 避免循环绘制
for i in range(1000):
ax.plot(x, y) # 慢

# 一次性绘制
ax.plot(x_matrix.T) # 快

# 4. 选择合适的后端
# 服务器端使用 Agg非交互
matplotlib.use('Agg')

性能问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 问题绘制大量数据点很慢
# 解决方案
# 1. 降采样
x_downsampled = x[::10]
y_downsampled = y[::10]

# 2. 使用 scatter 的 rasterized 参数
ax.scatter(x, y, s=1, rasterized=True)

# 3. 使用 LineCollection大量线段
from matplotlib.collections import LineCollection
segments = [[[x[i], y[i]], [x[i+1], y[i+1]]] for i in range(len(x)-1)]
lc = LineCollection(segments, linewidths=1)
ax.add_collection(lc)

# 4. 考虑使用其他库
# - Plotly交互式适合 Web
# - Bokeh交互式 dashboard
# - Datashader超大规模数据

显示问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 问题Jupyter 中图表不显示
# 解决方案
# 1. 添加魔法命令
%matplotlib inline

# 2. 或者
%matplotlib notebook # 交互式

# 3. 显式调用 plt.show()
plt.plot(x, y)
plt.show()

# 问题图表模糊
# 解决方案
# 1. 提高 DPI
plt.savefig('plot.png', dpi=300)

# 2. 使用矢量格式
plt.savefig('plot.pdf')
plt.savefig('plot.svg')

# 3. 设置 retina 显示Jupyter
%config InlineBackend.figure_format = 'retina'

常见陷阱

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 1. 忘记调用 tight_layout
fig, axes = plt.subplots(2, 2)
# ... 绘图代码 ...
plt.tight_layout() # 避免标签重叠
plt.show()

# 2. 中文乱码
# 确保设置了中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 3. 颜色循环耗尽
# 手动设置颜色或使用 colormap
colors = plt.cm.viridis(np.linspace(0, 1, num_lines))

# 4. 图例重复
# 只给需要的元素添加 label
for i in range(5):
ax.plot(x, y, label=f'Line {i}' if i == 0 else None)
ax.legend()

# 5. 内存泄漏
# 及时关闭不用的图形
plt.close('all') # 关闭所有图形
gc.collect() # 垃圾回收

Seaborn

Seaborn 是基于 Matplotlib 的高级 Python 数据可视化库提供了简洁的 API 和美观的默认样式它专注于统计图表的绘制能够轻松创建复杂的信息丰富的可视化作品特别适合探索性数据分析

  • 核心功能
    • 统计图表分布图回归图分类图等
    • 关系可视化散点图线图热力图
    • 分类数据箱线图小提琴图条形图
    • 矩阵可视化热力图聚类图
    • 多变量分析成对关系图联合分布图
  • 主要优势
    • 简洁易用API 设计直观代码量少
    • 美观默认内置多种主题和配色方案
    • Pandas 集成直接支持 DataFrame
    • 统计友好自动计算统计量并可视化

核心组件

  • 图形级别函数
    • relplot关系图
    • catplot分类图
    • displot分布图
    • jointplot联合分布图
    • pairplot成对关系图
  • 坐标轴级别函数
    • scatterplotlineplot散点图线图
    • barplotboxplot条形图箱线图
    • histplotkdeplot直方图密度图
    • heatmapclustermap热力图聚类图

工作原理

绘图流程

1
2
3
4
5
6
1. 准备数据 → DataFrame 格式
2. 选择图表类型 → 关系/分类/分布/矩阵
3. 设置美学参数 → 主题配色上下文
4. 调用绘图函数 → 传入数据和映射
5. 定制细节 → 标签标题图例
6. 显示或保存 → plt.show() 或 plt.savefig()

与 Matplotlib 的关系

1
2
3
4
5
Seaborn 构建在 Matplotlib 之上
- Seaborn 提供高级抽象和美观默认值
- Matplotlib 提供底层绘图能力
- 可以混合使用两者
- Seaborn 图表可以用 Matplotlib 进一步定制

环境搭建

pip 安装

1
2
3
4
5
6
7
8
9
10
11
# 基础安装
pip install seaborn

# 指定版本安装
pip install seaborn==0.13.0

# 升级 Seaborn
pip install --upgrade seaborn

# 安装完整功能包含所有可选依赖
pip install seaborn[stats]

conda 安装

1
2
3
4
5
6
7
8
# 使用 conda 安装
conda install seaborn

# 指定版本
conda install seaborn=0.13.0

# 在指定环境中安装
conda install -n myenv seaborn

验证安装

1
2
3
4
5
6
7
8
9
10
11
import seaborn as sns
print(sns.__version__)
# 输出: 0.13.0

# 查看可用数据集
print(sns.get_dataset_names())

# 测试绘图
tips = sns.load_dataset('tips')
sns.scatterplot(data=tips, x='total_bill', y='tip')
plt.show()

依赖环境

1
2
3
4
5
6
7
8
9
10
seaborn 主要依赖
- matplotlib: 底层绘图引擎
- numpy: 数值计算
- pandas: 数据处理
- scipy: 统计计算可选
- statsmodels: 统计模型可选

推荐依赖
- jupyter: 交互式开发
- pillow: 图像处理

💗💗 requirements.txt

1
2
3
4
5
seaborn==0.13.0
matplotlib==3.8.2
numpy==1.26.2
pandas==2.1.4
scipy==1.11.4

主题设置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import seaborn as sns
import matplotlib.pyplot as plt

# 设置主题
sns.set_theme(style="whitegrid") # 白色网格
sns.set_theme(style="darkgrid") # 深色网格
sns.set_theme(style="dark") # 深色背景
sns.set_theme(style="white") # 纯白背景
sns.set_theme(style="ticks") # 刻度线

# 恢复默认
sns.set_theme()

# 永久设置
sns.set_style("whitegrid")

配色方案

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 查看可用配色
print(sns.color_palette())

# 设置调色板
sns.set_palette("husl") # HUSL 色彩空间
sns.set_palette("muted") # 柔和色调
sns.set_palette("pastel") # pastel 色调
sns.set_palette("bright") # 明亮色调
sns.set_palette("deep") # 深色调
sns.set_palette("colorblind") # 色盲友好

# 自定义调色板
custom_palette = sns.color_palette("viridis", 10)
sns.set_palette(custom_palette)

# 分类调色板
categorical_palette = sns.color_palette("Set2", 8)
sequential_palette = sns.color_palette("Blues", 10)
diverging_palette = sns.color_palette("RdBu_r", 10)

上下文设置

1
2
3
4
5
6
7
8
# 设置绘图上下文控制元素大小
sns.set_context("paper") # 论文最小
sns.set_context("notebook") # 笔记本默认
sns.set_context("talk") # 演讲较大
sns.set_context("poster") # 海报最大

# 自定义字体缩放
sns.set_context("notebook", font_scale=1.5)

中文支持

1
2
3
4
5
# 设置中文字体同 Matplotlib
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False

# Seaborn 会自动继承 Matplotlib 的字体设置

关系图表

基本散点图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import seaborn as sns
import matplotlib.pyplot as plt

# 加载示例数据
tips = sns.load_dataset('tips')

# 简单散点图
sns.scatterplot(data=tips, x='total_bill', y='tip')
plt.title('Tip vs Total Bill')
plt.show()

# 添加颜色映射
sns.scatterplot(data=tips, x='total_bill', y='tip', hue='sex')
plt.title('Tip by Gender')
plt.show()

# 添加样式和大小编码
sns.scatterplot(data=tips, x='total_bill', y='tip',
hue='day', style='time', size='size',
sizes=(20, 200))
plt.title('Multi-encoded Scatter Plot')
plt.show()

高级散点图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 透明度控制
sns.scatterplot(data=tips, x='total_bill', y='tip',
alpha=0.6, edgecolor=None)
plt.title('Transparent Points')
plt.show()

# 自定义标记
markers = {'Male': 'o', 'Female': 's'}
sns.scatterplot(data=tips, x='total_bill', y='tip',
hue='sex', style='sex', markers=markers)
plt.show()

# 结合回归线
sns.regplot(data=tips, x='total_bill', y='tip',
scatter_kws={'alpha': 0.5},
line_kws={'color': 'red'})
plt.title('Scatter with Regression Line')
plt.show()

基本线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 加载时间序列数据
flights = sns.load_dataset('flights')

# 简单线图
sns.lineplot(data=flights, x='year', y='passengers')
plt.title('Air Passengers Over Time')
plt.show()

# 多组线图
sns.lineplot(data=flights, x='year', y='passengers', hue='month')
plt.title('Passengers by Month')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

带置信区间的线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 加载实验数据
fmri = sns.load_dataset('fmri')

# 自动计算均值和置信区间
sns.lineplot(data=fmri, x='timepoint', y='signal', hue='event')
plt.title('fMRI Signal Over Time')
plt.xlabel('Time Point')
plt.ylabel('Signal')
plt.show()

# 禁用置信区间
sns.lineplot(data=fmri, x='timepoint', y='signal',
hue='event', errorbar=None)
plt.title('Without Confidence Interval')
plt.show()

# 自定义置信区间
sns.lineplot(data=fmri, x='timepoint', y='signal',
hue='event', errorbar=('ci', 95)) # 95% 置信区间
plt.show()

样式定制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 不同线条样式
sns.lineplot(data=fmri, x='timepoint', y='signal',
hue='region', style='event',
dashes=True, markers=True)
plt.title('Styled Line Plot')
plt.show()

# 自定义颜色和标记
palette = {'stim': '#FF6B6B', 'cue': '#4ECDC4'}
markers = {'stim': 'o', 'cue': 's'}
sns.lineplot(data=fmri, x='timepoint', y='signal',
hue='event', palette=palette,
markers=markers, linewidth=2)
plt.show()

分面关系图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 创建分面网格
g = sns.relplot(data=tips, x='total_bill', y='tip',
col='time', row='sex',
hue='day', style='smoker',
kind='scatter', height=4, aspect=1.2)

g.fig.suptitle('Faceted Relationship Plot', y=1.02)
plt.show()

# 线图分面
g = sns.relplot(data=fmri, x='timepoint', y='signal',
col='region', hue='event',
kind='line', height=4, aspect=1)

g.fig.suptitle('Faceted Line Plot', y=1.02)
plt.show()

分布图表

基本直方图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 加载数据
penguins = sns.load_dataset('penguins')

# 简单直方图
sns.histplot(data=penguins, x='bill_length_mm')
plt.title('Bill Length Distribution')
plt.show()

# 调整 bins
sns.histplot(data=penguins, x='bill_length_mm', bins=30)
plt.title('More Bins')
plt.show()

# 核密度估计叠加
sns.histplot(data=penguins, x='bill_length_mm', kde=True)
plt.title('Histogram with KDE')
plt.show()

分组直方图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 按类别分组
sns.histplot(data=penguins, x='bill_length_mm', hue='species',
multiple='stack', alpha=0.6)
plt.title('Stacked Histogram')
plt.show()

# 并排显示
sns.histplot(data=penguins, x='bill_length_mm', hue='species',
multiple='dodge', element='step')
plt.title('Dodged Histogram')
plt.show()

# 重叠显示
sns.histplot(data=penguins, x='bill_length_mm', hue='species',
multiple='layer', alpha=0.3)
plt.title('Overlapping Histogram')
plt.show()

二维直方图

1
2
3
4
5
6
7
8
9
10
# 二维直方图
sns.histplot(data=penguins, x='bill_length_mm', y='bill_depth_mm',
bins=20, cmap='viridis')
plt.title('2D Histogram')
plt.show()

# 热力图风格
sns.histplot(data=penguins, x='bill_length_mm', y='bill_depth_mm',
bins=20, cmap='YlOrRd')
plt.show()

基本密度图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 单变量密度图
sns.kdeplot(data=penguins, x='bill_length_mm')
plt.title('Kernel Density Estimate')
plt.show()

# 填充区域
sns.kdeplot(data=penguins, x='bill_length_mm', fill=True, alpha=0.3)
plt.title('Filled KDE')
plt.show()

# 多组密度图
sns.kdeplot(data=penguins, x='bill_length_mm', hue='species',
fill=True, alpha=0.3)
plt.title('Multiple KDE')
plt.show()

二维密度图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 等高线图
sns.kdeplot(data=penguins, x='bill_length_mm', y='bill_depth_mm')
plt.title('2D KDE Contour')
plt.show()

# 填充等高线
sns.kdeplot(data=penguins, x='bill_length_mm', y='bill_depth_mm',
fill=True, cmap='viridis', levels=10)
plt.title('Filled 2D KDE')
plt.show()

# 按类别分组
sns.kdeplot(data=penguins, x='bill_length_mm', y='bill_depth_mm',
hue='species', fill=True, alpha=0.3)
plt.title('Grouped 2D KDE')
plt.show()

灵活分布图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 自动选择最佳展示方式
sns.displot(data=penguins, x='bill_length_mm')
plt.title('Auto Distribution Plot')
plt.show()

# 强制直方图
sns.displot(data=penguins, x='bill_length_mm', kind='hist', kde=True)
plt.title('Histogram + KDE')
plt.show()

# 强制密度图
sns.displot(data=penguins, x='bill_length_mm', kind='kde', fill=True)
plt.title('Density Plot')
plt.show()

# ECDF 图
sns.displot(data=penguins, x='bill_length_mm', kind='ecdf')
plt.title('Empirical CDF')
plt.show()

分组分布

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 按类别分面
g = sns.displot(data=penguins, x='bill_length_mm', hue='species',
kind='kde', fill=True, alpha=0.3)
g.set_title('Distribution by Species')
plt.show()

# 多变量分布
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for idx, var in enumerate(['bill_length_mm', 'bill_depth_mm', 'flipper_length_mm']):
sns.kdeplot(data=penguins, x=var, hue='species',
fill=True, alpha=0.3, ax=axes[idx])
axes[idx].set_title(var)

plt.tight_layout()
plt.show()

分类图表

基本条形图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 聚合数据
titanic = sns.load_dataset('titanic')

# 简单条形图自动计算均值
sns.barplot(data=titanic, x='class', y='age')
plt.title('Average Age by Class')
plt.show()

# 计数条形图
sns.countplot(data=titanic, x='class')
plt.title('Count by Class')
plt.show()

# 水平条形图
sns.barplot(data=titanic, y='class', x='age')
plt.title('Horizontal Bar Plot')
plt.show()

分组条形图

1
2
3
4
5
6
7
8
9
10
11
# 按性别分组
sns.barplot(data=titanic, x='class', y='age', hue='sex')
plt.title('Age by Class and Sex')
plt.legend(title='Sex')
plt.show()

# 自定义颜色和误差线
sns.barplot(data=titanic, x='class', y='age', hue='sex',
palette='Set2', errwidth=2, capsize=0.1)
plt.title('Styled Bar Plot')
plt.show()

基本箱线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 简单箱线图
sns.boxplot(data=tips, x='day', y='total_bill')
plt.title('Bill Distribution by Day')
plt.show()

# 添加分组
sns.boxplot(data=tips, x='day', y='total_bill', hue='sex')
plt.title('Bill by Day and Sex')
plt.show()

# 水平箱线图
sns.boxplot(data=tips, y='day', x='total_bill')
plt.title('Horizontal Box Plot')
plt.show()

样式定制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 自定义颜色
palette = {'Thur': '#FF6B6B', 'Fri': '#4ECDC4',
'Sat': '#45B7D1', 'Sun': '#FFA07A'}
sns.boxplot(data=tips, x='day', y='total_bill', palette=palette)
plt.title('Custom Colors')
plt.show()

# 移除离群点
sns.boxplot(data=tips, x='day', y='total_bill', showfliers=False)
plt.title('Without Outliers')
plt.show()

# 自定义离群点样式
sns.boxplot(data=tips, x='day', y='total_bill',
flierprops=dict(marker='o', markerfacecolor='red',
markersize=8, linestyle='none'))
plt.show()

基本小提琴图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 简单小提琴图
sns.violinplot(data=tips, x='day', y='total_bill')
plt.title('Violin Plot')
plt.show()

# 内嵌箱线图
sns.violinplot(data=tips, x='day', y='total_bill', inner='box')
plt.title('Violin with Box')
plt.show()

# 内嵌 stick
sns.violinplot(data=tips, x='day', y='total_bill', inner='stick')
plt.title('Violin with Stick')
plt.show()

分割小提琴图

1
2
3
4
5
6
7
8
9
10
11
12
# 按性别分割
sns.violinplot(data=tips, x='day', y='total_bill',
hue='sex', split=True)
plt.title('Split Violin Plot')
plt.show()

# 自定义样式
sns.violinplot(data=tips, x='day', y='total_bill',
hue='sex', split=True,
palette='muted', inner='quartile')
plt.title('Styled Split Violin')
plt.show()

分面分类图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 创建分面网格
g = sns.catplot(data=tips, x='day', y='total_bill',
col='time', kind='box', height=4, aspect=1)

g.fig.suptitle('Faceted Box Plot', y=1.02)
plt.show()

# 小提琴图分面
g = sns.catplot(data=tips, x='day', y='total_bill',
row='sex', kind='violin', height=3, aspect=2)

g.fig.suptitle('Faceted Violin Plot', y=1.02)
plt.show()

# 条形图分面
g = sns.catplot(data=titanic, x='class', y='survived',
col='sex', kind='bar', height=4, aspect=1)

g.fig.suptitle('Survival Rate by Class and Sex', y=1.02)
plt.show()

矩阵图表

基本热力图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 创建相关矩阵
flights = sns.load_dataset('flights')
flights_pivot = flights.pivot_table(index='month',
columns='year',
values='passengers')

# 简单热力图
sns.heatmap(flights_pivot)
plt.title('Flight Passengers Heatmap')
plt.show()

# 添加数值标注
sns.heatmap(flights_pivot, annot=True, fmt='d')
plt.title('Heatmap with Values')
plt.show()

# 自定义颜色
sns.heatmap(flights_pivot, annot=True, fmt='d', cmap='YlGnBu')
plt.title('Custom Colormap')
plt.show()

高级热力图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 相关系数热力图
corr_matrix = penguins.corr(numeric_only=True)

sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='coolwarm',
square=True, linewidths=0.5, center=0)
plt.title('Correlation Heatmap')
plt.tight_layout()
plt.show()

# 掩码上三角
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
sns.heatmap(corr_matrix, mask=mask, annot=True, fmt='.2f',
cmap='coolwarm', square=True, linewidths=0.5)
plt.title('Lower Triangle Correlation')
plt.tight_layout()
plt.show()

# 自定义刻度
sns.heatmap(flights_pivot, annot=True, fmt='d', cmap='viridis',
linewidths=0.5, linecolor='gray',
cbar_kws={'label': 'Passengers'})
plt.title('Styled Heatmap')
plt.show()

层次聚类热力图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 加载基因表达数据
genes = sns.load_dataset('genes')

# 基本聚类图
g = sns.clustermap(genes.pivot_table(index='gene',
columns='cell_type',
values='expression'),
cmap='viridis')
g.fig.suptitle('Clustermap', y=1.02)
plt.show()

# 自定义聚类
g = sns.clustermap(flights_pivot, cmap='YlOrRd',
standard_scale=1, # 按列标准化
figsize=(10, 8))
g.fig.suptitle('Standardized Clustermap', y=1.02)
plt.show()

# 行和列聚类
g = sns.clustermap(corr_matrix.abs(), cmap='Reds',
row_cluster=True, col_cluster=True,
figsize=(8, 8))
g.fig.suptitle('Hierarchical Clustering', y=1.02)
plt.show()

多变量图表

基本成对图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 简单成对图
sns.pairplot(penguins.dropna())
plt.suptitle('Pairplot', y=1.02)
plt.show()

# 按类别着色
sns.pairplot(penguins.dropna(), hue='species')
plt.suptitle('Colored Pairplot', y=1.02)
plt.show()

# 自定义对角线
sns.pairplot(penguins.dropna(), hue='species',
diag_kind='kde', plot_kws={'alpha': 0.6})
plt.suptitle('KDE Diagonal', y=1.02)
plt.show()

高级成对图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 选择特定变量
vars = ['bill_length_mm', 'bill_depth_mm', 'flipper_length_mm']
sns.pairplot(penguins.dropna(), vars=vars, hue='species',
corner=True) # 只显示下三角
plt.suptitle('Corner Pairplot', y=1.02)
plt.show()

# 自定义图表类型
g = sns.pairplot(penguins.dropna(), hue='species',
diag_kind='hist',
plot_kws={'marker': 'D', 's': 50, 'alpha': 0.6},
diag_kws={'edgecolor': 'black'})
g.fig.suptitle('Customized Pairplot', y=1.02)
plt.show()

基本联合图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 散点 + 边际分布
sns.jointplot(data=penguins, x='bill_length_mm', y='bill_depth_mm')
plt.suptitle('Joint Plot', y=1.02)
plt.show()

# 六边形 bin
sns.jointplot(data=penguins, x='bill_length_mm', y='bill_depth_mm',
kind='hex', gridsize=20)
plt.suptitle('Hexbin Joint Plot', y=1.02)
plt.show()

# 核密度估计
sns.jointplot(data=penguins, x='bill_length_mm', y='bill_depth_mm',
kind='kde', fill=True)
plt.suptitle('KDE Joint Plot', y=1.02)
plt.show()

回归联合图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 散点 + 回归线
sns.jointplot(data=tips, x='total_bill', y='tip',
kind='reg', scatter_kws={'alpha': 0.5})
plt.suptitle('Regression Joint Plot', y=1.02)
plt.show()

# 按类别分组
g = sns.jointplot(data=tips, x='total_bill', y='tip',
hue='sex', kind='scatter')
g.fig.suptitle('Grouped Joint Plot', y=1.02)
plt.show()

# 残差图
sns.residplot(data=tips, x='total_bill', y='tip')
plt.title('Residual Plot')
plt.show()

回归图表

基本回归图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 简单回归
sns.regplot(data=tips, x='total_bill', y='tip')
plt.title('Linear Regression')
plt.show()

# 多项式回归
sns.regplot(data=tips, x='total_bill', y='tip', order=2)
plt.title('Polynomial Regression (order=2)')
plt.show()

# Logistic 回归
titanic = sns.load_dataset('titanic')
sns.regplot(data=titanic, x='age', y='survived',
logistic=True, y_jitter=0.05)
plt.title('Logistic Regression')
plt.show()

分面回归图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# lmplot 创建分面网格
g = sns.lmplot(data=tips, x='total_bill', y='tip',
col='time', hue='smoker',
height=4, aspect=1)

g.fig.suptitle('Faceted Regression', y=1.02)
plt.show()

# 多变量回归
g = sns.lmplot(data=penguins.dropna(),
x='bill_length_mm', y='bill_depth_mm',
hue='species', col='sex',
height=4, aspect=1)

g.fig.suptitle('Multi-variable Regression', y=1.02)
plt.show()

主题和样式

内置主题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 查看所有主题
print(sns.axes_style())

# 临时切换主题
with sns.axes_style("whitegrid"):
sns.scatterplot(data=tips, x='total_bill', y='tip')
plt.title('Whitegrid Theme')
plt.show()

# 永久设置
sns.set_style("darkgrid", {"grid.color": ".9", "grid.linestyle": "--"})
sns.scatterplot(data=tips, x='total_bill', y='tip')
plt.title('Custom Grid')
plt.show()

自定义样式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 完全自定义
custom_style = {
'axes.facecolor': '#EAEAF2',
'axes.edgecolor': 'white',
'axes.grid': True,
'grid.color': 'white',
'grid.linestyle': '-',
'font.family': 'sans-serif',
'font.size': 12
}

sns.set_style(custom_style)
sns.scatterplot(data=tips, x='total_bill', y='tip')
plt.title('Fully Customized')
plt.show()

调色板类型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 定性调色板分类数据
qualitative = sns.color_palette("Set1", 8)
sns.palplot(qualitative)
plt.title('Qualitative Palette')
plt.show()

# 顺序调色板连续数据
sequential = sns.color_palette("Blues", 10)
sns.palplot(sequential)
plt.title('Sequential Palette')
plt.show()

# 发散调色板偏离中心
diverging = sns.color_palette("RdBu_r", 10)
sns.palplot(diverging)
plt.title('Diverging Palette')
plt.show()

自定义调色板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 从颜色列表创建
custom = sns.color_palette(["#FF6B6B", "#4ECDC4", "#45B7D1", "#FFA07A"])
sns.palplot(custom)
plt.title('Custom Colors')
plt.show()

# 渐变色
gradient = sns.light_palette("navy", reverse=True)
sns.palplot(gradient)
plt.title('Light Palette')
plt.show()

# 暗色系
dark = sns.dark_palette("purple")
sns.palplot(dark)
plt.title('Dark Palette')
plt.show()

# 混合调色板
mixed = sns.blend_palette(["blue", "yellow"], 10)
sns.palplot(mixed)
plt.title('Blended Palette')
plt.show()

上下文和字体

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 不同上下文对比
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

contexts = ['paper', 'notebook', 'talk', 'poster']
for ax, ctx in zip(axes.flat, contexts):
with sns.plotting_context(ctx):
sns.scatterplot(data=tips.head(20), x='total_bill', y='tip', ax=ax)
ax.set_title(f'Context: {ctx}')

plt.tight_layout()
plt.show()

# 自定义字体缩放
sns.set_context("notebook", font_scale=1.5,
rc={"lines.linewidth": 2.5})
sns.scatterplot(data=tips, x='total_bill', y='tip')
plt.title('Scaled Context')
plt.show()

与 Pandas 集成

DataFrame 直接绘图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 创建示例数据
np.random.seed(42)
df = pd.DataFrame({
'date': pd.date_range('2024-01-01', periods=100),
'value1': np.random.randn(100).cumsum(),
'value2': np.random.randn(100).cumsum() + 10,
'category': np.random.choice(['A', 'B', 'C'], 100),
'group': np.random.choice(['X', 'Y'], 100)
})

# 时间序列线图
sns.lineplot(data=df, x='date', y='value1', label='Value 1')
sns.lineplot(data=df, x='date', y='value2', label='Value 2')
plt.title('Time Series')
plt.legend()
plt.show()

# 分类箱线图
sns.boxplot(data=df, x='category', y='value1')
plt.title('Category Distribution')
plt.show()

# 分组条形图
summary = df.groupby(['category', 'group'])['value1'].mean().reset_index()
sns.barplot(data=summary, x='category', y='value1', hue='group')
plt.title('Grouped Summary')
plt.show()

统计摘要可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 描述性统计
desc = df.describe()
print(desc)

# 统计图表
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 分布
sns.histplot(data=df, x='value1', kde=True, ax=axes[0, 0])
axes[0, 0].set_title('Distribution')

# 箱线图
sns.boxplot(data=df, y='value1', ax=axes[0, 1])
axes[0, 1].set_title('Box Plot')

# 小提琴图
sns.violinplot(data=df, x='category', y='value1', ax=axes[1, 0])
axes[1, 0].set_title('Violin Plot')

# 相关热力图
corr = df[['value1', 'value2']].corr()
sns.heatmap(corr, annot=True, ax=axes[1, 1], cmap='coolwarm')
axes[1, 1].set_title('Correlation')

plt.tight_layout()
plt.show()

性能优化

代码规范

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 1. 使用数据驱动的方式
# 好
sns.scatterplot(data=df, x='column1', y='column2', hue='category')

# 不好
sns.scatterplot(x=df['column1'], y=df['column2'])

# 2. 链式操作保持清晰
g = (
sns.catplot(data=df, x='category', y='value',
col='group', kind='box')
.set_titles("{col_name}")
.set_axis_labels("Category", "Value")
)

# 3. 及时关闭图形
for category in df['category'].unique():
subset = df[df['category'] == category]
sns.histplot(data=subset, x='value')
plt.title(category)
plt.savefig(f'{category}.png')
plt.close()

# 4. 使用上下文管理器
with sns.axes_style("whitegrid"):
sns.scatterplot(data=df, x='x', y='y')
plt.show()

性能优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 1. 大数据集降采样
large_df = pd.DataFrame({'x': np.random.randn(100000),
'y': np.random.randn(100000)})

# 降采样
sampled = large_df.sample(n=1000, random_state=42)
sns.scatterplot(data=sampled, x='x', y='y')
plt.show()

# 2. 使用 alpha 透明度处理过绘
sns.scatterplot(data=large_df.sample(10000), x='x', y='y',
alpha=0.1, s=10)
plt.show()

# 3. 选择合适的图表类型
# 大量数据用 hexbin 而非 scatter
sns.jointplot(data=large_df.sample(10000), x='x', y='y',
kind='hex', gridsize=30)
plt.show()

# 4. 避免不必要的计算
# 预计算统计量再绘图
summary = df.groupby('category')['value'].agg(['mean', 'std']).reset_index()
sns.barplot(data=summary, x='category', y='mean')
plt.show()

常见陷阱

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 1. 忘记处理缺失值
# 错误
sns.scatterplot(data=df, x='x', y='y') # 可能有 NaN

# 正确
sns.scatterplot(data=df.dropna(subset=['x', 'y']), x='x', y='y')

# 2. 颜色映射不当
# 错误对分类数据使用顺序调色板
sns.scatterplot(data=df, x='x', y='y', hue='category',
palette='Blues') # 不合适

# 正确使用定性调色板
sns.scatterplot(data=df, x='x', y='y', hue='category',
palette='Set1')

# 3. 图例混乱
# 太多类别时考虑分面而非颜色编码
# 错误
sns.scatterplot(data=df, x='x', y='y', hue='many_categories')

# 正确
sns.relplot(data=df, x='x', y='y', col='many_categories',
kind='scatter')

# 4. 忽略数据分布
# 检查分布再选择图表
sns.displot(data=df, x='value') # 先了解分布
plt.show()

性能问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 问题绘制大量数据点很慢
# 解决方案
# 1. 降采样
df_sampled = df.sample(n=1000, random_state=42)

# 2. 使用 alpha 和较小标记
sns.scatterplot(data=df_sampled, x='x', y='y', alpha=0.3, s=10)

# 3. 使用 hexbin 或 kde
sns.jointplot(data=df_sampled, x='x', y='y', kind='hex')

# 4. 考虑其他库
# - Plotly交互式适合 Web
# - Datashader超大规模数据
# - Bokeh大型数据集

美观性问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 问题图表不够美观
# 解决方案
# 1. 使用内置主题
sns.set_theme(style="whitegrid", palette="muted")

# 2. 选择合适的配色
sns.set_palette("husl")

# 3. 调整上下文
sns.set_context("talk", font_scale=1.2)

# 4. 添加标题和标签
plt.title('Clear Title', fontsize=14, fontweight='bold')
plt.xlabel('X Label', fontsize=12)
plt.ylabel('Y Label', fontsize=12)

# 5. 调整布局
plt.tight_layout()

学习资源

  • 视频
    • 数据分析全套视频教程https://www.bilibili.com/video/BV1ReshzoEgG
  • 书籍
    • 利用 Jupyter 进行数据科学社区编写
    • Python 数据科学手册Jake VanderPlas 著包含 Jupyter 章节
    • Jupyter 交互式数据分析社区编写
  • 工具
    • Google Colab云端 Jupyter 环境
    • Kaggle Notebooks数据科学竞赛平台
    • Binder可分享的在线 Notebook
    • NBViewer在线查看 Notebook
  • Anaconda
    • 官方文档https://docs.anaconda.com/
  • Jupyter
    • Jupyter 官方文档https://jupyter.org/documentation
    • Jupyter Notebook 文档https://jupyter-notebook.readthedocs.io/
    • JupyterLab 文档https://jupyterlab.readthedocs.io/
    • Jupyter GitHubhttps://github.com/jupyter
  • NumPy
    • NumPy 官方文档:https://numpy.org/doc/stable/
    • NumPy 中文文档:https://www.numpy.org.cn/
    • NumPy GitHub:https://github.com/numpy/numpy
  • SQLAlchemyhttps://docs.sqlalchemy.org/
    • SQLAlchemy 官方文档https://docs.sqlalchemy.org/
    • SQLAlchemy GitHubhttps://github.com/sqlalchemy/sqlalchemy
  • Matplotlib
    • Matplotlib 图表模版https://matplotlib.org/stable/gallery/index.html