import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
# -------------------------- Basic configuration (incl. sample-display tuning parameters) --------------------------
input_file = "data.txt" # input file path
output_fig = "sample_optimized_analysis.pdf" # output figure path
# Sample-display tuning parameters (edit directly)
sample_label_rotation = 60 # sample-label rotation angle in degrees (0~90; 45/60/90 recommended)
sample_sort_by = "performance" # sample order: "performance" (best->worst mean) or "name" (alphabetical)
filter_samples = None # subset of samples to show (e.g. ["Sample1", "Sample3"]); None = show all
sample_label_fontsize_offset = 0 # sample-label font-size offset (+1 = larger, -1 = smaller)
max_samples_display = 30 # max samples shown (extras are truncated to avoid overcrowding)
# Other basic configuration
cl_color_palette = plt.cm.Set3 # qualitative colormap for CL main types
nl_color_palette = plt.cm.Pastel1 # qualitative colormap for NL main types
top_cl_sort_by = "Total_Records" # column used to rank the "top" CL types
max_display_cl = 8 # max CL types to show  # NOTE(review): appears unused below — heatmap counts use literals; confirm
max_display_nl = 6 # max NL types to show  # NOTE(review): appears unused below; confirm
# -------------------------- Utility: extract the main type --------------------------
def extract_main_type(type_str):
    """Return the main type: everything before the first "-", stripped.

    E.g. "CL1-rep2" -> "CL1"; a string without "-" is returned stripped
    unchanged.  str.split always yields at least one element, so the
    original `if "-" in type_str` branch was redundant.
    """
    return type_str.split("-")[0].strip()
# -------------------------- 1. Data loading and preprocessing --------------------------
print("Loading data...")
df = pd.read_csv(
    input_file,
    sep="\t",  # tab-separated; switch to sep="\s+" if the file is whitespace-delimited
    header=None,
    names=["Sample_Name", "CL", "NL", "Value"],
    dtype={"Sample_Name": str, "CL": str, "NL": str}
)
# Cleaning: drop incomplete rows, keep only rows whose CL/NL columns carry
# the expected "CL"/"NL" prefixes, then coerce Value to numeric and drop
# anything that failed to parse.
df_clean = df.dropna(subset=["Sample_Name", "CL", "NL", "Value"]).copy()
df_clean = df_clean[df_clean["CL"].str.startswith("CL", na=False)]
df_clean = df_clean[df_clean["NL"].str.startswith("NL", na=False)]
df_clean["Value"] = pd.to_numeric(df_clean["Value"], errors="coerce")
df_clean = df_clean.dropna(subset=["Value"])
# Extract CL/NL main types (merges "-"-suffixed duplicates, e.g. "CL1-rep2" -> "CL1")
df_clean["CL_Main"] = df_clean["CL"].apply(extract_main_type)
df_clean["NL_Main"] = df_clean["NL"].apply(extract_main_type)
# Optional sample filtering
if filter_samples is not None:
    df_clean = df_clean[df_clean["Sample_Name"].isin(filter_samples)]
    print(f"Filtered samples: {filter_samples}")
# Sample ordering (smaller mean = better, so ascending = best first)
sample_summary = df_clean.groupby("Sample_Name")["Value"].agg("mean").reset_index()
if sample_sort_by == "performance":
    sample_order = sample_summary.sort_values("Value", ascending=True)["Sample_Name"].tolist()  # best -> worst
else:
    sample_order = sorted(df_clean["Sample_Name"].unique())  # alphabetical by name
# Cap the number of displayed samples
if len(sample_order) > max_samples_display:
    sample_order = sample_order[:max_samples_display]
    df_clean = df_clean[df_clean["Sample_Name"].isin(sample_order)]
    print(f"Too many samples! Displaying top {max_samples_display} samples (sorted by {sample_sort_by})")
n_samples = len(sample_order)
n_cls = df_clean["CL_Main"].nunique()
n_nls = df_clean["NL_Main"].nunique()
print(f"Final display: {n_samples} samples | {n_cls} CL main types | {n_nls} NL main types")
# NOTE(review): the emptiness check runs after the counts above; this is
# harmless because groupby/nunique tolerate an empty frame.
if len(df_clean) == 0:
    raise ValueError("No valid data after filtering! Check sample names in 'filter_samples'")
# -------------------------- Core: fine-grained layout parameters per sample count --------------------------
print(f"\nOptimizing layout for {n_samples} samples...")
# Bucket the sample count into one of four tiers: <=8, <=15, <=25, >25.
_tier = sum(n_samples > bound for bound in (8, 15, 25))
# 1. Figure size (wider as the tier grows; width also scales within a tier)
fig_size = (
    (18, 14),
    (20 + (n_samples - 8) // 2, 14),   # +1 width per 2 extra samples
    (24 + (n_samples - 15) // 3, 16),  # +1 width per 3 extra samples
    (28, 18),                          # capped at width 28 to limit file size
)[_tier]
# 2. Base font size per tier plus the user-configurable offset (floor: 5 pt)
font_size = max(5, (9, 8, 7, 6)[_tier] + sample_label_fontsize_offset)
# 3. Bar width per tier (keeps one sample's grouped bars from overlapping)
bar_width = (0.15, 0.12, 0.1, 0.08)[_tier]
# 4. How many TOP CL/NL types to draw in the grouped bar charts
top_cl_count = min((6, 5, 4, 3)[_tier], n_cls)
top_nl_count = min((6, 5, 4, 3)[_tier], n_nls)
# 5. How many CL/NL types the heatmap shows (compressed further)
display_cl_count = min(6 if n_samples <= 15 else 4, n_cls)
display_nl_count = min(5 if n_samples <= 15 else 3, n_nls)
print(f"Optimized params: fig_size={fig_size}, font_size={font_size}, bar_width={bar_width}")
print(f"TOP display: CL={top_cl_count}, NL={top_nl_count} | Heatmap: CL={display_cl_count}, NL={display_nl_count}")
# -------------------------- 2. Statistical analysis --------------------------
print("\n" + "="*80)
print(f"Statistical Analysis: {n_samples} Samples (Optimized Display) | Smaller Value = Better")
print("="*80)
# Detailed stats per (sample, CL main type, NL main type)
detail_stats = df_clean.groupby(["Sample_Name", "CL_Main", "NL_Main"])["Value"].agg([
    "count", "mean", "std", "min", "max", "median"
]).round(2)
detail_stats.columns = ["Record_Count", "Mean", "Std_Dev", "Minimum", "Maximum", "Median"]
detail_stats = detail_stats.reset_index()
# Order rows by the configured sample display order, then by Mean within a sample
detail_stats["Sample_Name"] = pd.Categorical(detail_stats["Sample_Name"], categories=sample_order, ordered=True)
detail_stats = detail_stats.sort_values(["Sample_Name", "Mean"], ascending=True)
# Per-sample summary, reindexed to the configured display order
sample_summary = df_clean.groupby("Sample_Name")["Value"].agg([
    "count", "mean", "std", "min", "max", "median"
]).round(2)
sample_summary.columns = ["Total_Records", "Overall_Mean", "Overall_Std", "Min_Value", "Max_Value", "Median_Value"]
sample_summary = sample_summary.reindex(sample_order)
cl_summary = df_clean.groupby("CL_Main")["Value"].agg([
    "count", "mean", "std", "min", "max"
]).round(2)
cl_summary.columns = ["Total_Records", "Mean", "Std_Dev", "Minimum", "Maximum"]
# "Top" CL types: the MOST records when ranking by Total_Records, or the
# SMALLEST mean (smaller = better) when ranking by Mean.
# BUG FIX: the original `ascending=top_cl_sort_by!="Mean"` was inverted, so
# head() downstream picked the LEAST-populated CL types as "top".
cl_summary = cl_summary.sort_values(top_cl_sort_by, ascending=top_cl_sort_by == "Mean")
nl_summary = df_clean.groupby("NL_Main")["Value"].agg([
    "count", "mean", "std", "min", "max"
]).round(2)
nl_summary.columns = ["Total_Records", "Mean", "Std_Dev", "Minimum", "Maximum"]
nl_summary = nl_summary.sort_values("Mean", ascending=True)  # smaller mean = better
# Key global metrics (best sample = lowest overall mean)
best_sample = sample_summary["Overall_Mean"].idxmin()
global_metrics = {
    "Best_Performing_Sample": best_sample,
    "Best_Sample_Mean": sample_summary.loc[best_sample, "Overall_Mean"],
    "Total_Displayed_Samples": n_samples,
    "Sample_Sort_Method": sample_sort_by,
    "Overall_Data_Mean": round(df_clean["Value"].mean(), 2)
}
print("\nGlobal Key Metrics:")
for metric, value in global_metrics.items():
    print(f"{metric.replace('_', ' ')}: {value}")
# -------------------------- 3. Save statistics tables --------------------------
# All outputs are tab-separated despite the .csv extension.  The detail
# table drops its integer index; the summaries keep the group key as index.
detail_stats.to_csv("sample_optimized_detail.csv", sep="\t", index=False)
sample_summary.to_csv("sample_optimized_sample_summary.csv", sep="\t")
cl_summary.to_csv("sample_optimized_cl_summary.csv", sep="\t")
nl_summary.to_csv("sample_optimized_nl_summary.csv", sep="\t")
# -------------------------- 4. Sample-optimized visualization --------------------------
print("\nGenerating sample-optimized visualization...")
# Global matplotlib state: tick/legend/label sizes derived from the adaptive
# font_size computed above.
# NOTE(review): 'Arial' may be absent on some systems; matplotlib then falls
# back to its default font with a warning — confirm this is acceptable.
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['figure.autolayout'] = True
plt.rcParams['xtick.labelsize'] = font_size - 1
plt.rcParams['ytick.labelsize'] = font_size - 1
plt.rcParams['legend.fontsize'] = font_size - 1
plt.rcParams['axes.labelsize'] = font_size
plt.rcParams['axes.titlesize'] = font_size + 1
# Prepare shared plotting data
unique_cls = sorted(cl_summary.index)
unique_nls = sorted(nl_summary.index)
cl_color_size = cl_color_palette.N  # number of colors available in each palette
nl_color_size = nl_color_palette.N
# Subplot layout (extra vertical room so rotated sample labels don't collide)
fig = plt.figure(figsize=fig_size)
gs = GridSpec(3, 2, figure=fig, height_ratios=[1, 1.6, 1.1])  # taller heatmap / sample rows
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, :])  # heatmap spans the full width
ax4 = fig.add_subplot(gs[2, 0])
ax5 = fig.add_subplot(gs[2, 1])
# Figure title (records the sample-optimization settings)
fig.suptitle(
    f"Optimized Analysis: {n_samples} Samples (Sorted by {sample_sort_by}) | Smaller Value = Better",
    fontsize=font_size + 3, fontweight="bold", y=0.98
)
# 4.1 Subplot 1: sample x TOP CL grouped bars (tuned bar width/offsets)
top_cls = cl_summary.head(top_cl_count).index
x = np.arange(n_samples)
# Symmetric per-type offsets around each sample's tick position
offsets = np.linspace(-bar_width*(top_cl_count-1)/2, bar_width*(top_cl_count-1)/2, top_cl_count)
for i, cl in enumerate(top_cls):
    # Mean Value per sample for this CL type.  .mean() of an empty selection
    # is already NaN, so the original explicit emptiness check was redundant;
    # .loc replaces the chained-indexing pattern df[mask]["Value"].
    means = [
        df_clean.loc[(df_clean["Sample_Name"] == s) & (df_clean["CL_Main"] == cl), "Value"].mean()
        for s in sample_order
    ]
    color = cl_color_palette(i % cl_color_size)
    ax1.bar(x + offsets[i], means, bar_width, label=cl, color=color, alpha=0.8, edgecolor="black")
ax1.set_title(f"Sample vs Top {top_cl_count} CL Main Types", fontsize=font_size + 1)
ax1.set_xlabel("Sample Name", fontsize=font_size)
ax1.set_ylabel("Mean Value (Smaller = Better)", fontsize=font_size)
ax1.set_xticks(x)
ax1.set_xticklabels(sample_order, rotation=sample_label_rotation, ha="right", fontsize=font_size - 1)
# Legend placement: move it above the axes when there are many samples
legend_loc = "upper right" if n_samples <= 15 else "upper center"
ax1.legend(loc=legend_loc, bbox_to_anchor=(1, 1) if n_samples <= 15 else (0.5, 1.2), ncol=top_cl_count)
ax1.grid(True, alpha=0.3, axis="y")
# 4.2 Subplot 2: sample x TOP NL grouped bars (same treatment as 4.1)
top_nls = nl_summary.head(top_nl_count).index
offsets_nl = np.linspace(-bar_width*(top_nl_count-1)/2, bar_width*(top_nl_count-1)/2, top_nl_count)
for i, nl in enumerate(top_nls):
    # Mean Value per sample for this NL type; .mean() of an empty selection
    # is already NaN, and .loc replaces the chained df[mask]["Value"] lookup.
    means = [
        df_clean.loc[(df_clean["Sample_Name"] == s) & (df_clean["NL_Main"] == nl), "Value"].mean()
        for s in sample_order
    ]
    color = nl_color_palette(i % nl_color_size)
    ax2.bar(x + offsets_nl[i], means, bar_width, label=nl, color=color, alpha=0.8, edgecolor="black")
ax2.set_title(f"Sample vs Top {top_nl_count} NL Main Types", fontsize=font_size + 1)
ax2.set_xlabel("Sample Name", fontsize=font_size)
ax2.set_ylabel("Mean Value (Smaller = Better)", fontsize=font_size)
ax2.set_xticks(x)
ax2.set_xticklabels(sample_order, rotation=sample_label_rotation, ha="right", fontsize=font_size - 1)
ax2.legend(loc=legend_loc, bbox_to_anchor=(1, 1) if n_samples <= 15 else (0.5, 1.2), ncol=top_nl_count)
ax2.grid(True, alpha=0.3, axis="y")
# 4.3 Subplot 3: heatmap (optimized sample and column labels)
cl_display = unique_cls[:display_cl_count]
nl_display = unique_nls[:display_nl_count]
# Pivot to samples x (CL, NL) mean values; reindex keeps the configured
# sample display order on the rows.
heatmap_data = df_clean[
    (df_clean["CL_Main"].isin(cl_display)) & (df_clean["NL_Main"].isin(nl_display))
].pivot_table(
    index="Sample_Name",
    columns=["CL_Main", "NL_Main"],
    values="Value",
    aggfunc="mean",
    fill_value=np.nan
).reindex(sample_order)
if not heatmap_data.empty:
    im = ax3.imshow(heatmap_data.values, cmap="RdYlBu_r", aspect="auto")
    # Row labels: one per sample, horizontal for readability
    ax3.set_yticks(np.arange(len(heatmap_data.index)))
    ax3.set_yticklabels(heatmap_data.index, fontsize=font_size - 1, rotation=0)
    # Column labels: flatten the (CL, NL) MultiIndex into "CL-NL" strings
    col_labels = [f"{cl}-{nl}" for cl, nl in heatmap_data.columns]  # dash instead of a line break
    ax3.set_xticks(np.arange(len(col_labels)))
    ax3.set_xticklabels(col_labels, fontsize=font_size - 2, rotation=45 if len(col_labels) <= 10 else 60)
    # Mark missing (NaN) cells with an "N" so they aren't mistaken for data
    for i in range(len(heatmap_data.index)):
        for j in range(len(heatmap_data.columns)):
            if np.isnan(heatmap_data.iloc[i, j]):
                ax3.text(j, i, "N", ha="center", va="center", fontsize=font_size - 2, color="black")
    # Colorbar sized/labelled to match the adaptive font size
    cbar = plt.colorbar(im, ax=ax3, shrink=0.7)
    cbar.set_label("Mean Value (Smaller = Better)", fontsize=font_size)
# Title/axis labels are set even when the heatmap is empty (matches original flow)
ax3.set_title(
    f"Mean Value Heatmap: Sample × CL × NL (Blue = Better)",
    fontsize=font_size + 1
)
ax3.set_xlabel("CL-NL Main Type", fontsize=font_size)
ax3.set_ylabel("Sample Name", fontsize=font_size)
# 4.4 Subplot 4: CL x NL mean scatter (marker size reflects data volume)
cl_nl_mean = df_clean.groupby(["CL_Main", "NL_Main"])["Value"].agg(["mean", "count"]).reset_index()
cl_nl_mean = cl_nl_mean.sort_values("mean", ascending=True)
# Marker area scales with record count; shrink when many samples are shown.
size_factor = 3 if n_samples <= 15 else 2
for idx, cl_type in enumerate(unique_cls[:5]):  # cap at 5 CL types so the legend stays readable
    subset = cl_nl_mean[cl_nl_mean["CL_Main"] == cl_type]
    if subset.empty:
        continue
    ax4.scatter(
        range(len(subset)),
        subset["mean"],
        s=subset["count"] * size_factor,
        color=cl_color_palette(idx % cl_color_size),
        label=cl_type,
        alpha=0.8,
        edgecolor="black"
    )
ax4.set_title("CL×NL Mean (Smaller = Better, Size=Data Volume)", fontsize=font_size + 1)
ax4.set_xlabel("Combination Index", fontsize=font_size)
ax4.set_ylabel("Mean Value", fontsize=font_size)
ax4.legend(loc="upper right", bbox_to_anchor=(1.3, 1) if n_samples <= 15 else (1.2, 1))
ax4.grid(True, alpha=0.3, axis="y")
# 4.5 Subplot 5: per-sample overall mean, annotated with each sample's best CL-NL combo
sample_means = sample_summary["Overall_Mean"].values
sample_colors = plt.cm.Set2(np.linspace(0, 1, n_samples))
# Bar width adapts to the number of samples
bar_width_sample = 0.9 if n_samples <= 8 else 0.8 if n_samples <= 15 else 0.7 if n_samples <= 25 else 0.6
ax5.bar(range(n_samples), sample_means, color=sample_colors, alpha=0.8, edgecolor="black", width=bar_width_sample)
# Annotate each bar with its best (lowest-Value) CL-NL combination; shrink
# the labels when many samples are displayed (floor: 4 pt).
label_font = font_size - 2 if n_samples <= 15 else font_size - 3
label_font = max(4, label_font)
for i, s in enumerate(sample_order):
    sample_data = df_clean[df_clean["Sample_Name"] == s]
    if len(sample_data) > 0:
        best_comb = sample_data.loc[sample_data["Value"].idxmin()]
        # Compact single-line "CL-NL" label instead of a two-line one
        label = f"{best_comb['CL_Main']}-{best_comb['NL_Main']}"
        # Few samples: label floats above the bar; many samples: centered inside it
        y_pos = sample_means[i] + (sample_means.max() - sample_means.min())*0.01 if n_samples <= 15 else sample_means[i]
        va = "bottom" if n_samples <= 15 else "center"
        color = "black" if n_samples <= 15 else "white"
        # BUG FIX: "transparent" is not a valid matplotlib color and raised
        # ValueError whenever n_samples > 15; "none" is the supported way to
        # get an invisible box face.
        ax5.text(
            i, y_pos, label, ha="center", va=va, fontsize=label_font,
            bbox=dict(boxstyle="round,pad=0.1", fc="white" if n_samples <= 15 else "none", alpha=0.7),
            color=color
        )
ax5.set_title("Sample Overall Mean (Best → Worst)", fontsize=font_size + 1)
ax5.set_xlabel("Sample Name", fontsize=font_size)
ax5.set_ylabel("Overall Mean Value (Smaller = Better)", fontsize=font_size)
ax5.set_xticks(range(n_samples))
ax5.set_xticklabels(sample_order, rotation=sample_label_rotation, ha="right", fontsize=font_size - 1)
ax5.grid(True, alpha=0.3, axis="y")
# Final layout tuning (generous spacing so rotated sample labels don't overlap)
plt.subplots_adjust(
    top=0.95,
    right=0.85 if n_samples > 15 else 0.9,
    left=0.1 if n_samples > 20 else 0.08,  # wider left margin when many samples
    hspace=0.4,  # extra vertical gap between subplot rows
    wspace=0.3
)
# Save the figure; lower dpi for larger sample counts to keep the PDF small.
# NOTE(review): figure.autolayout (enabled earlier) and bbox_inches="tight"
# both adjust layout — confirm the intended one takes effect.
dpi = 300 if n_samples <= 15 else 200 if n_samples <= 25 else 150
plt.savefig(output_fig, dpi=dpi, bbox_inches="tight", format="pdf")
plt.close()
print(f"Sample-optimized visualization saved to: {output_fig}")
# Final console summary of the applied display settings.
print(f"\nOptimization completed! Key adjustments:")
# BUG FIX: this line previously printed sample_label_fontsize_offset (the
# font offset) under the "rotation" label; the rotation angle is what the
# degree sign refers to.
print(f"- Sample label rotation: {sample_label_rotation}°")
print(f"- Sample sort method: {sample_sort_by}")
print(f"- Displayed samples: {n_samples}")
print(f"- Font size: {font_size} | Bar width: {bar_width}")