plink --vcf all1.vcf --make-bed --out all
plink -bfile all --pca
生成特征向量文件plink.eigenvec，特征值文件plink.eigenval
绘图
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plink PCA Result Visualization Tool
Function: Read plink.eigenvec and plink.eigenval → Auto-adapt PC count → Plot PCA scatter plots → Save results
Dependencies: pandas, numpy, matplotlib, seaborn (core libraries only)
"""
import os
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def parse_args():
    """Define and parse the command-line interface for this tool."""
    parser = argparse.ArgumentParser(
        description="Plink PCA Result Visualization Tool (PC count mismatch compatible)"
    )
    # (flags, options) specs registered in a single pass below.
    arg_specs = [
        (("-v", "--eigenvec"),
         dict(default="plink.eigenvec",
              help="Path to Plink eigenvector file (default: plink.eigenvec)")),
        (("-l", "--eigenval"),
         dict(default="plink.eigenval",
              help="Path to Plink eigenvalue file (default: plink.eigenval)")),
        (("-o", "--output"),
         dict(default="plink_pca_results",
              help="Output directory (default: plink_pca_results)")),
        (("-g", "--group"),
         dict(help="Sample group file (optional, format: SampleID\tGroupName)")),
        (("-p", "--pcs"),
         dict(nargs="+", default=["1-2", "1-3"],
              help="PC pairs to plot (default: 1-2 1-3, format: x-y, e.g., 2-4)")),
    ]
    for flags, options in arg_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
def read_plink_pca(eigenvec_path, eigenval_path):
    """
    Read Plink PCA result files.

    Fixes over the naive reader:
      - FID/IID columns parsed as integers no longer break the sample-ID join
        (cast to str first).
      - A single eigenvalue no longer breaks len() (np.loadtxt returns a
        0-d array in that case; np.atleast_1d normalizes it).
      - PC-count mismatch is handled in BOTH directions by truncating to the
        count present in both files (the original only handled
        eigenvec > eigenval and crashed on the opposite case).

    Parameters
    ----------
    eigenvec_path : str
        Path to plink.eigenvec (per-sample scores: FID IID PC1 ... PCn).
    eigenval_path : str
        Path to plink.eigenval (one eigenvalue per line, PC1..PCn order).

    Returns
    -------
    df_pca : pd.DataFrame
        Rows = samples; columns PC1..PCn plus a "sample_id" column (FID_IID).
    pca_stats : pd.DataFrame
        Per-PC eigenvalue, explained variance ratio and cumulative ratio.
    """
    print(f"Reading Plink PCA result files...")
    # 1. Eigenvector file: whitespace-separated, FID IID PC1 ... PCn
    df_eigenvec = pd.read_csv(
        eigenvec_path,
        sep=r"\s+",       # raw string avoids the invalid-escape warning
        header=None,
        engine="python"   # compatible with complex separators
    )
    # FID/IID may be read as numbers; cast so the string join never raises.
    df_eigenvec["sample_id"] = (
        df_eigenvec[0].astype(str) + "_" + df_eigenvec[1].astype(str)
    )
    # PC columns occupy everything between IID and the appended sample_id.
    n_pc_eigenvec = df_eigenvec.shape[1] - 3
    print(f"Eigenvector file: {df_eigenvec.shape[0]} samples × {n_pc_eigenvec} PCs (raw)")
    # 2. Eigenvalue file: one value per line, PC1..PCn.
    #    atleast_1d guards the single-PC case (loadtxt returns a scalar array).
    eigenvals = np.atleast_1d(np.loadtxt(eigenval_path))
    n_pc_eigenval = len(eigenvals)
    print(f"Eigenvalue file: {n_pc_eigenval} eigenvalues (corresponding to PC1~PCn)")
    # 3. Reconcile PC counts: keep only the PCs present in BOTH files.
    n_pc = min(n_pc_eigenvec, n_pc_eigenval)
    if n_pc_eigenvec != n_pc_eigenval:
        print(f"Warning: PC count mismatch! Eigenvector file has {n_pc_eigenvec} PCs, eigenvalue file has {n_pc_eigenval} PCs.")
        print(f"Automatically truncating both files to the first {n_pc} PCs")
        eigenvals = eigenvals[:n_pc]
    # Build PCA scores DataFrame (force numeric dtype for the PC columns)
    pc_columns = [f"PC{i+1}" for i in range(n_pc)]
    df_pca = pd.DataFrame(
        data=df_eigenvec.iloc[:, 2:2 + n_pc].values.astype(float),
        columns=pc_columns
    )
    df_pca["sample_id"] = df_eigenvec["sample_id"].values
    print(f"Final PCA scores: {df_pca.shape[0]} samples × {n_pc} PCs")
    # 4. Explained variance ratio and cumulative ratio
    explained_variance = eigenvals / np.sum(eigenvals)
    cumulative_variance = np.cumsum(explained_variance)
    pca_stats = pd.DataFrame({
        "Principal_Component": pc_columns,
        "Eigenvalue": eigenvals.round(4),
        "Explained_Variance_Ratio": explained_variance.round(4),
        "Cumulative_Variance_Ratio": cumulative_variance.round(4)
    })
    print("\nPCA Statistics (Top 10 PCs):")
    print(pca_stats.head(10).to_string(index=False))
    return df_pca, pca_stats
def validate_pc_pairs(pc_pairs, max_pc):
    """Parse user-supplied PC pair strings into ("PCx", "PCy") tuples.

    Each entry must look like "x-y" with two distinct PCs in [1, max_pc].
    Malformed or out-of-range entries are reported and dropped; raises
    ValueError if nothing valid remains.
    """
    valid_pairs = []
    for raw in pc_pairs:
        if "-" not in raw:
            print(f"Warning: Invalid PC pair format ({raw}), skipped (correct format: x-y, e.g., 1-2)")
            continue
        try:
            left, right = raw.split("-")
            idx_x, idx_y = int(left), int(right)
        except ValueError:
            print(f"Warning: Invalid PC pair format ({raw}), skipped")
            continue
        if not (1 <= idx_x <= max_pc and 1 <= idx_y <= max_pc):
            print(f"Warning: PC pair {raw} out of range (max PC: {max_pc}), skipped")
            continue
        if idx_x == idx_y:
            print(f"Warning: PC pair {raw} cannot be the same PC, skipped")
            continue
        valid_pairs.append((f"PC{idx_x}", f"PC{idx_y}"))
    if not valid_pairs:
        raise ValueError("No valid PC pairs provided, program exited")
    return valid_pairs
def plot_plink_pca(df_pca, pca_stats, valid_pc_pairs, group_file=None, output_dir="plink_pca_results"):
    """
    Plot PCA scatter plots for each requested PC pair.

    Fix: seaborn ignores the `palette` argument when `hue` is None, so the
    original code's single-color setting ("#2E86AB") was never applied in
    the ungrouped case; the color is now passed via `color=` instead.

    Parameters
    ----------
    df_pca : pd.DataFrame
        PCA scores with PC1..PCn columns and a "sample_id" column.
    pca_stats : pd.DataFrame
        Per-PC statistics; explained variance ratios go into axis labels.
    valid_pc_pairs : list of (str, str)
        Validated ("PCx", "PCy") pairs, e.g. from validate_pc_pairs().
    group_file : str, optional
        Tab-separated file of SampleID<TAB>GroupName used to color samples.
    output_dir : str
        Directory where PNGs are written (created if missing).
    """
    print(f"\nStarting PCA plot generation (output directory: {output_dir})")
    os.makedirs(output_dir, exist_ok=True)
    # Global plot style
    plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica']  # English fonts
    plt.rcParams['axes.unicode_minus'] = False
    sns.set_style("whitegrid")
    plt.rcParams['figure.figsize'] = (10, 8)  # Uniform plot size
    # Resolve optional per-sample grouping
    single_color = "#2E86AB"
    if group_file:
        # Group file format: SampleID\tGroupName
        df_group = pd.read_csv(
            group_file,
            sep="\t",
            header=None,
            names=["sample_id", "group"],
            engine="python"
        )
        df_pca_plot = pd.merge(df_pca, df_group, on="sample_id", how="left")
        # Samples absent from the group file are labelled "Unknown"
        df_pca_plot["group"] = df_pca_plot["group"].fillna("Unknown")
        hue = "group"
        n_groups = df_pca_plot["group"].nunique()
        palette = sns.color_palette("tab10", n_colors=n_groups)
        print(f"Group information: {n_groups} groups detected")
    else:
        df_pca_plot = df_pca.copy()
        hue = None
        palette = None
        print("No group file provided, all samples plotted in the same color")
    # Explained variance ratio per PC, for axis/title labels
    var_dict = dict(zip(pca_stats["Principal_Component"], pca_stats["Explained_Variance_Ratio"]))
    for pc_x, pc_y in valid_pc_pairs:
        fig, ax = plt.subplots()
        marker_style = dict(s=80, alpha=0.8, edgecolor="black", linewidth=0.5)
        if hue:
            sns.scatterplot(data=df_pca_plot, x=pc_x, y=pc_y, hue=hue,
                            palette=palette, ax=ax, **marker_style)
        else:
            # `color=` (not `palette=`) is required when there is no hue
            sns.scatterplot(data=df_pca_plot, x=pc_x, y=pc_y,
                            color=single_color, ax=ax, **marker_style)
        # Title and axis labels carry the explained variance ratios
        var_x = var_dict[pc_x]
        var_y = var_dict[pc_y]
        ax.set_title(
            f"PCA Analysis: {pc_x} ({var_x:.1%}) vs {pc_y} ({var_y:.1%})",
            fontsize=14,
            fontweight="bold"
        )
        ax.set_xlabel(f"{pc_x} ({var_x:.1%})", fontsize=12)
        ax.set_ylabel(f"{pc_y} ({var_y:.1%})", fontsize=12)
        # Move the legend outside the axes to avoid overlapping points
        if hue:
            ax.legend(
                title="Group",
                bbox_to_anchor=(1.05, 1),
                loc="upper left",
                fontsize=10
            )
        # Save at high resolution (300 DPI)
        output_path = os.path.join(output_dir, f"pca_{pc_x}_{pc_y}.png")
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"→ Saved: {output_path}")
    # Companion chart: explained variance per PC
    plot_variance(pca_stats, output_dir)
def plot_variance(pca_stats, output_dir):
    """Bar chart of per-PC explained variance with a cumulative-variance line.

    Only the first 10 PCs (or fewer, if less are available) are shown to keep
    the chart readable. The figure is saved to <output_dir>/pca_variance.png
    at 300 DPI.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    top_n = min(10, len(pca_stats))
    subset = pca_stats.head(top_n)
    positions = list(range(1, top_n + 1))
    individual = subset["Explained_Variance_Ratio"]
    cumulative = subset["Cumulative_Variance_Ratio"]
    # Bars: variance explained by each PC on its own
    ax.bar(
        positions,
        individual,
        alpha=0.8,
        color="#2E86AB",
        edgecolor="black",
        label="Individual PC Explained Variance"
    )
    # Line: running total of explained variance
    ax.plot(
        positions,
        cumulative,
        color="#A23B72",
        marker="o",
        linewidth=2,
        markersize=6,
        label="Cumulative Explained Variance"
    )
    ax.set_xlabel("Principal Component", fontsize=12)
    ax.set_ylabel("Explained Variance Ratio", fontsize=12)
    ax.set_title("PCA Explained Variance Ratio (Top 10 PCs)", fontsize=14, fontweight="bold")
    ax.set_xticks(positions)
    ax.set_xticklabels([f"PC{p}" for p in positions])
    ax.legend(fontsize=10)
    # Annotate each bar and each line point with its percentage
    for pos, var, cum in zip(positions, individual, cumulative):
        ax.text(pos, var + 0.005, f"{var:.1%}", ha="center", fontsize=9)
        ax.text(pos, cum + 0.005, f"{cum:.1%}", ha="center", fontsize=9, color="#A23B72")
    output_path = os.path.join(output_dir, "pca_variance.png")
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"→ Saved: {output_path}")
def save_results(df_pca, pca_stats, output_dir="plink_pca_results"):
    """Write the PCA scores and statistics as UTF-8 CSV files under output_dir.

    Creates output_dir if it does not exist; produces pca_scores.csv
    (per-sample PC scores) and pca_stats.csv (per-PC statistics).
    """
    os.makedirs(output_dir, exist_ok=True)
    # filename → DataFrame, written with identical CSV settings
    outputs = {
        "pca_scores.csv": df_pca,
        "pca_stats.csv": pca_stats,
    }
    for filename, frame in outputs.items():
        frame.to_csv(
            os.path.join(output_dir, filename),
            sep=",",
            index=False,
            encoding="utf-8"
        )
    print(f"\nNumerical results saved to: {output_dir}")
    print(" - pca_scores.csv: PCA scores for each sample")
    print(" - pca_stats.csv: PCA statistics (eigenvalue, explained variance ratio)")
def main():
    """Run the full pipeline: read results → validate pairs → plot → save."""
    args = parse_args()
    # Read Plink PCA results (tolerant of PC count mismatch)
    df_pca, pca_stats = read_plink_pca(args.eigenvec, args.eigenval)
    # Validate the requested PC pairs against the available PC count
    valid_pc_pairs = validate_pc_pairs(args.pcs, len(pca_stats))
    print(f"\nPC pairs to plot: {[f'{x}-{y}' for x, y in valid_pc_pairs]}")
    # Generate scatter plots (and the variance chart) then persist the tables
    plot_plink_pca(df_pca, pca_stats, valid_pc_pairs, args.group, args.output)
    save_results(df_pca, pca_stats, args.output)
    print("\n=== Plink PCA Visualization Completed Successfully! ===")
    print(f"Result files located at: {os.path.abspath(args.output)}")
if __name__ == "__main__":
    main()