PCA分析

plink --vcf all1.vcf --make-bed --out all


plink --bfile all --pca

默认计算前 20 个主成分,生成特征向量文件 plink.eigenvec(每个样本在各主成分上的坐标)和特征值文件 plink.eigenval(各主成分对应的特征值)。

绘图
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plink PCA Result Visualization Tool
Function: Read plink.eigenvec and plink.eigenval → Auto-adapt PC count → Plot PCA scatter plots → Save results
Dependencies: pandas, numpy, matplotlib, seaborn (core libraries only)
"""

import os
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def parse_args():
    """Build the command-line interface and return the parsed namespace.

    Options cover the eigenvector/eigenvalue input paths, the output
    directory, an optional sample-group file, and the "x-y" PC pairs to plot.
    """
    cli = argparse.ArgumentParser(
        description="Plink PCA Result Visualization Tool (PC count mismatch compatible)"
    )
    # (short flag, long flag, add_argument keyword options) for each option
    option_specs = (
        ("-v", "--eigenvec", {
            "default": "plink.eigenvec",
            "help": "Path to Plink eigenvector file (default: plink.eigenvec)",
        }),
        ("-l", "--eigenval", {
            "default": "plink.eigenval",
            "help": "Path to Plink eigenvalue file (default: plink.eigenval)",
        }),
        ("-o", "--output", {
            "default": "plink_pca_results",
            "help": "Output directory (default: plink_pca_results)",
        }),
        ("-g", "--group", {
            "help": "Sample group file (optional, format: SampleID\tGroupName)",
        }),
        ("-p", "--pcs", {
            "nargs": "+",
            "default": ["1-2", "1-3"],
            "help": "PC pairs to plot (default: 1-2 1-3, format: x-y, e.g., 2-4)",
        }),
    )
    for short_flag, long_flag, options in option_specs:
        cli.add_argument(short_flag, long_flag, **options)
    return cli.parse_args()

def read_plink_pca(eigenvec_path, eigenval_path):
    """
    Read PLINK PCA result files and derive per-PC variance statistics.

    Input:
      - eigenvec_path: Path to plink.eigenvec, whitespace-separated with columns
        FID IID PC1 ... PCn. A leading header row (e.g. "FID IID PC1 ..." or
        "#FID IID PC1 ...") is detected and skipped.
      - eigenval_path: Path to plink.eigenval, one eigenvalue per line in PC order.
    Output:
      - df_pca: DataFrame with float columns PC1..PCk plus "sample_id"
        ("FID_IID"), one row per sample, where k = min(#PC columns, #eigenvalues).
      - pca_stats: DataFrame with columns Principal_Component, Eigenvalue,
        Explained_Variance_Ratio and Cumulative_Variance_Ratio (rounded to 4 dp),
        one row per retained PC.

    Note: each ratio is an eigenvalue divided by the sum of all eigenvalues in
    eigenval_path, i.e. it is relative to the PCs PLINK computed, not to the
    total genetic variance.
    """
    print("Reading Plink PCA result files...")

    # 1. Eigenvectors. Read every cell as text so FID/IID keep their exact
    #    spelling: numeric IDs would otherwise be parsed as ints, which breaks
    #    the "_" concatenation below and drops leading zeros.
    df_eigenvec = pd.read_csv(
        eigenvec_path,
        sep=r"\s+",  # whitespace-separated (PLINK output format)
        header=None,
        dtype=str,
        engine="python",
    )
    # A header row carries labels ("PC1", ...) where data rows carry numbers.
    if pd.to_numeric(df_eigenvec.iloc[0, 2:], errors="coerce").isna().any():
        df_eigenvec = df_eigenvec.iloc[1:].reset_index(drop=True)

    # Count PC columns on the raw table (FID, IID come first). The merged
    # sample ID is kept separately so it can never be counted as a PC column.
    n_pc_eigenvec = df_eigenvec.shape[1] - 2
    sample_ids = df_eigenvec[0] + "_" + df_eigenvec[1]  # FID_IID
    print(f"Eigenvector file: {df_eigenvec.shape[0]} samples × {n_pc_eigenvec} PCs (raw)")

    # 2. Eigenvalues; ndmin=1 keeps a single-value file as a 1-element array.
    eigenvals = np.loadtxt(eigenval_path, ndmin=1)
    n_pc_eigenval = len(eigenvals)
    print(f"Eigenvalue file: {n_pc_eigenval} eigenvalues (corresponding to PC1~PCn)")

    # 3. Keep only the PCs present in both files (handles either side being longer).
    n_pc = min(n_pc_eigenvec, n_pc_eigenval)
    if n_pc_eigenvec != n_pc_eigenval:
        print(f"Warning: PC count mismatch! Eigenvector file has {n_pc_eigenvec} PCs, eigenvalue file has {n_pc_eigenval} PCs.")
        print(f"Automatically truncating to the first {n_pc} PCs")

    pc_columns = [f"PC{i + 1}" for i in range(n_pc)]
    df_pca = pd.DataFrame(
        data=df_eigenvec.iloc[:, 2:2 + n_pc].astype(float).values,
        columns=pc_columns,
    )
    df_pca["sample_id"] = sample_ids.values
    print(f"Final PCA scores: {df_pca.shape[0]} samples × {n_pc} PCs")

    # 4. Explained-variance ratios relative to the sum of all eigenvalues read.
    kept_eigenvals = eigenvals[:n_pc]
    explained_variance = kept_eigenvals / np.sum(eigenvals)
    cumulative_variance = np.cumsum(explained_variance)

    pca_stats = pd.DataFrame({
        "Principal_Component": pc_columns,
        "Eigenvalue": kept_eigenvals.round(4),
        "Explained_Variance_Ratio": explained_variance.round(4),
        "Cumulative_Variance_Ratio": cumulative_variance.round(4),
    })

    print("\nPCA Statistics (Top 10 PCs):")
    print(pca_stats.head(10).to_string(index=False))

    return df_pca, pca_stats

def validate_pc_pairs(pc_pairs, max_pc):
    """Convert "x-y" strings into ("PCx", "PCy") tuples, skipping bad entries.

    An entry is reported and skipped when it has no dash, does not split into
    exactly two integers, names a PC outside 1..max_pc, or names the same PC
    twice. Accepted pairs keep their input order.

    Raises:
        ValueError: if no entry survives validation.
    """
    accepted = []
    for spec in pc_pairs:
        if "-" not in spec:
            print(f"Warning: Invalid PC pair format ({spec}), skipped (correct format: x-y, e.g., 1-2)")
            continue
        try:
            first, second = map(int, spec.split("-"))
        except ValueError:
            print(f"Warning: Invalid PC pair format ({spec}), skipped")
            continue
        both_in_range = 1 <= first <= max_pc and 1 <= second <= max_pc
        if not both_in_range:
            print(f"Warning: PC pair {spec} out of range (max PC: {max_pc}), skipped")
        elif first == second:
            print(f"Warning: PC pair {spec} cannot be the same PC, skipped")
        else:
            accepted.append((f"PC{first}", f"PC{second}"))
    if not accepted:
        raise ValueError("No valid PC pairs provided, program exited")
    return accepted

def _attach_groups(df_pca, group_file):
    """Left-join group labels from a headerless SampleID<TAB>GroupName file onto df_pca.

    Returns the merged DataFrame (missing labels filled with "Unknown") and
    the number of distinct group labels in it.
    """
    # dtype=str keeps IDs as text so they compare equal to the string
    # sample_id column; numeric IDs would otherwise be parsed as int64 and
    # pandas refuses to merge int64 with object keys.
    df_group = pd.read_csv(
        group_file,
        sep="\t",
        header=None,
        names=["sample_id", "group"],
        dtype=str,
        engine="python",
    )
    # A sample listed twice would otherwise be duplicated (plotted twice) by the merge.
    df_group = df_group.drop_duplicates(subset="sample_id", keep="first")
    df_plot = pd.merge(df_pca, df_group, on="sample_id", how="left")
    n_unlabelled = int(df_plot["group"].isna().sum())
    if n_unlabelled:
        print(f"Warning: {n_unlabelled} samples have no group information, labelled as \"Unknown\"")
    df_plot["group"] = df_plot["group"].fillna("Unknown")
    return df_plot, df_plot["group"].nunique()


def plot_plink_pca(df_pca, pca_stats, valid_pc_pairs, group_file=None, output_dir="plink_pca_results"):
    """
    Plot one PCA scatter plot per PC pair, then the explained-variance chart.

    Args:
      - df_pca: DataFrame with PC1..PCk columns and a "sample_id" column.
      - pca_stats: DataFrame with "Principal_Component" and
        "Explained_Variance_Ratio" columns (used for axis labels).
      - valid_pc_pairs: list of ("PCx", "PCy") column-name tuples.
      - group_file: optional headerless tab-separated file "SampleID<TAB>GroupName".
        SampleID must match df_pca["sample_id"] exactly (the "FID_IID" form
        built by read_plink_pca); unmatched samples are labelled "Unknown".
      - output_dir: directory for pca_<PCx>_<PCy>.png (300 DPI); created if missing.
    """
    print(f"\nStarting PCA plot generation (output directory: {output_dir})")
    os.makedirs(output_dir, exist_ok=True)

    # seaborn's set_style rewrites font.sans-serif, so apply it first and the
    # font preference afterwards; DejaVu Sans ships with matplotlib as a fallback.
    sns.set_style("whitegrid")
    plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams['figure.figsize'] = (10, 8)  # Uniform plot size

    if group_file:
        df_pca_plot, n_groups = _attach_groups(df_pca, group_file)
        hue = "group"
        # tab10 has only 10 distinct colours; beyond that use evenly spaced hues.
        palette_name = "tab10" if n_groups <= 10 else "husl"
        colour_kwargs = {"palette": sns.color_palette(palette_name, n_colors=n_groups)}
        print(f"Group information: {n_groups} groups detected")
    else:
        df_pca_plot = df_pca.copy()
        hue = None
        # seaborn ignores `palette` when no hue is assigned; `color` applies one colour.
        colour_kwargs = {"color": "#2E86AB"}
        print("No group file provided, all samples plotted in the same color")

    # Explained variance ratio per PC, for the axis labels.
    var_dict = dict(zip(pca_stats["Principal_Component"], pca_stats["Explained_Variance_Ratio"]))

    for pc_x, pc_y in valid_pc_pairs:
        fig, ax = plt.subplots()
        sns.scatterplot(
            data=df_pca_plot,
            x=pc_x,
            y=pc_y,
            hue=hue,
            s=80,  # Marker size
            alpha=0.8,  # Transparency
            edgecolor="black",
            linewidth=0.5,  # Marker border
            ax=ax,
            **colour_kwargs,
        )

        var_x = var_dict[pc_x]
        var_y = var_dict[pc_y]
        ax.set_title(
            f"PCA Analysis: {pc_x} ({var_x:.1%}) vs {pc_y} ({var_y:.1%})",
            fontsize=14,
            fontweight="bold",
        )
        ax.set_xlabel(f"{pc_x} ({var_x:.1%})", fontsize=12)
        ax.set_ylabel(f"{pc_y} ({var_y:.1%})", fontsize=12)

        if hue:
            # Place the legend outside the axes so it never covers points.
            ax.legend(
                title="Group",
                bbox_to_anchor=(1.05, 1),
                loc="upper left",
                fontsize=10,
            )

        output_path = os.path.join(output_dir, f"pca_{pc_x}_{pc_y}.png")
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        print(f"→ Saved: {output_path}")

    plot_variance(pca_stats, output_dir)

def plot_variance(pca_stats, output_dir, top_n=10):
    """
    Plot a bar chart of per-PC explained variance with a cumulative-variance line.

    Args:
      - pca_stats: DataFrame with "Explained_Variance_Ratio" and
        "Cumulative_Variance_Ratio" columns, one row per PC in PC order.
      - output_dir: existing directory where pca_variance.png (300 DPI) is written.
      - top_n: maximum number of leading PCs to show (default 10).
    """
    n_show = min(top_n, len(pca_stats))
    pca_top = pca_stats.head(n_show)
    x = range(1, n_show + 1)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(
        x,
        pca_top["Explained_Variance_Ratio"],
        alpha=0.8,
        color="#2E86AB",
        edgecolor="black",
        label="Individual PC Explained Variance",
    )
    ax.plot(
        x,
        pca_top["Cumulative_Variance_Ratio"],
        color="#A23B72",
        marker="o",
        linewidth=2,
        markersize=6,
        label="Cumulative Explained Variance",
    )

    ax.set_xlabel("Principal Component", fontsize=12)
    ax.set_ylabel("Explained Variance Ratio", fontsize=12)
    # The title reports how many PCs are actually shown (may be fewer than top_n).
    ax.set_title(f"PCA Explained Variance Ratio (Top {n_show} PCs)", fontsize=14, fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels([f"PC{i}" for i in x])
    ax.legend(fontsize=10)

    # Value labels just above each bar and each cumulative point.
    for pos, (var, cum_var) in enumerate(
        zip(pca_top["Explained_Variance_Ratio"], pca_top["Cumulative_Variance_Ratio"]), start=1
    ):
        ax.text(pos, var + 0.005, f"{var:.1%}", ha="center", fontsize=9)
        ax.text(pos, cum_var + 0.005, f"{cum_var:.1%}", ha="center", fontsize=9, color="#A23B72")

    output_path = os.path.join(output_dir, "pca_variance.png")
    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"→ Saved: {output_path}")

def save_results(df_pca, pca_stats, output_dir="plink_pca_results"):
    """Write the PCA scores and per-PC statistics as UTF-8 CSV files without an index.

    Files written into output_dir (created if needed):
      - pca_scores.csv from df_pca
      - pca_stats.csv from pca_stats
    """
    os.makedirs(output_dir, exist_ok=True)

    outputs = {"pca_scores.csv": df_pca, "pca_stats.csv": pca_stats}
    for file_name, frame in outputs.items():
        target = os.path.join(output_dir, file_name)
        frame.to_csv(target, sep=",", index=False, encoding="utf-8")

    summary_lines = (
        f"\nNumerical results saved to: {output_dir}",
        "  - pca_scores.csv: PCA scores for each sample",
        "  - pca_stats.csv: PCA statistics (eigenvalue, explained variance ratio)",
    )
    for line in summary_lines:
        print(line)

def main():
    """Command-line entry point: read PLINK PCA output, plot it, and save the tables."""
    args = parse_args()

    # Load scores and statistics; the number of PCs retained bounds the pair check.
    df_pca, pca_stats = read_plink_pca(args.eigenvec, args.eigenval)
    valid_pc_pairs = validate_pc_pairs(args.pcs, len(pca_stats))
    pair_labels = ["-".join(pair) for pair in valid_pc_pairs]
    print(f"\nPC pairs to plot: {pair_labels}")

    # Figures first, then the numeric CSV outputs, both into args.output.
    plot_plink_pca(df_pca, pca_stats, valid_pc_pairs, args.group, args.output)
    save_results(df_pca, pca_stats, args.output)

    print("\n=== Plink PCA Visualization Completed Successfully! ===")
    print(f"Result files located at: {os.path.abspath(args.output)}")


if __name__ == "__main__":
    main()
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容