61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
import numpy as np
|
|
import re
|
|
|
|
# Load and parse raw data from a text file
def parse_raw_data(file_path):
    """Parse experiment result records from a text file.

    A record is a parenthesised activation combo, whitespace, a learning
    rate, whitespace, and a parenthesised comma-separated list of
    accuracies, e.g. ``(relu, tanh) 1e-3 (0.52, 0.71, 0.83)``. The record
    may appear anywhere in a line; lines without a record are skipped.

    Args:
        file_path: Path to the text file to read.

    Returns:
        A list of ``(activation_combo, learning_rate, top_accuracies)``
        tuples in file order. ``activation_combo`` is the raw text inside
        the first parentheses, ``learning_rate`` is a float, and
        ``top_accuracies`` is a tuple of floats in the order they appear.
    """
    # A decimal number with optional sign, fraction and exponent.
    # Scientific notation must be accepted because learning rates are
    # commonly written as e.g. "1e-3". A bare [\d.]+ would reject those
    # lines outright, and would accept malformed tokens like "." that
    # then crash float().
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    pattern = re.compile(
        rf"\((.*?)\)\s+({number})\s+\(\s*({number}(?:\s*,\s*{number})*)\s*\)"
    )

    data = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            match = pattern.search(line)
            if match is None:
                continue
            activation_combo = match.group(1)
            learning_rate = float(match.group(2))
            # float() ignores the surrounding whitespace permitted around commas.
            top_accuracies = tuple(float(v) for v in match.group(3).split(","))
            data.append((activation_combo, learning_rate, top_accuracies))
    return data


# Path to the parsed results file; replace this with your actual path.
file_path = "./parsed.txt"

data = parse_raw_data(file_path)
if not data:
    # Without this guard an empty parse only surfaces later as an opaque
    # pandas column-count error on the top-k split below.
    raise ValueError(f"No result lines could be parsed from {file_path!r}")

# Convert to DataFrame: one row per parsed result line.
df = pd.DataFrame(data, columns=["activation_combo", "learning_rate", "accuracy"])

# Split each accuracy tuple into top-1 / top-3 / top-5 columns. Every tuple
# must have exactly three values. Otherwise a mix of 2- and 3-value tuples
# would be silently padded with NaN, and other mismatches fail with an
# opaque column-count error.
accuracy_lengths = sorted({len(acc) for acc in df["accuracy"]})
if accuracy_lengths != [3]:
    raise ValueError(
        "Expected exactly 3 accuracy values (top-1, top-3, top-5) per line, "
        f"found tuple lengths {accuracy_lengths}"
    )
df[["top1", "top3", "top5"]] = pd.DataFrame(df["accuracy"].tolist(), index=df.index)


# Width of a single bar, in units of the spacing between learning-rate groups.
bar_width = 0.2

# Bars are grouped by learning rate (ascending). Within each group there is one
# bar per activation combo, taken in order of first appearance in the data.
activation_combos = pd.unique(df["activation_combo"])
learning_rates = sorted(pd.unique(df["learning_rate"]))

# Starting x-position of each learning-rate group.
x = np.arange(len(learning_rates))

# Start plotting: three stacked subplots sharing the learning-rate x-axis.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(14, 16))


# Function to draw bars
def plot_bars(ax, column, title, ylabel):
    """Draw a grouped bar chart of one accuracy metric on ``ax``.

    Reads the module-level ``df``, ``activation_combos``, ``learning_rates``,
    ``x`` and ``bar_width``. Each learning rate owns the group starting at
    its entry in ``x``. The i-th activation combo's bar sits
    ``i * bar_width`` to the right of its group's start.

    Args:
        ax: Matplotlib axes to draw on.
        column: ``df`` column holding the bar heights (e.g. ``"top1"``).
        title: Title for the subplot.
        ylabel: Y-axis label.
    """
    # Place each bar using its own learning rate's group index rather than
    # pairing rows with ``x`` by position. Positional pairing misaligns the
    # bars, or raises a shape-mismatch error in ax.bar, as soon as a combo
    # was not run at every learning rate.
    group_of = {lr: idx for idx, lr in enumerate(learning_rates)}
    for i, combo in enumerate(activation_combos):
        combo_data = df[df["activation_combo"] == combo].sort_values("learning_rate")
        group_idx = combo_data["learning_rate"].map(group_of).to_numpy()
        ax.bar(
            x[group_idx] + i * bar_width,
            combo_data[column],
            width=bar_width,
            label=combo,
        )

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(title="Activation Combo")


# Plot each accuracy type on its own subplot, top to bottom.
for target_ax, (metric_column, metric_label) in zip(
    (ax1, ax2, ax3),
    (
        ("top1", "Top-1 Accuracy"),
        ("top3", "Top-3 Accuracy"),
        ("top5", "Top-5 Accuracy"),
    ),
):
    plot_bars(target_ax, metric_column, metric_label, metric_label)

# X-axis ticks: one per learning rate, centred under its group of bars.
group_centres = x + bar_width * (len(activation_combos) - 1) / 2
ax3.set_xticks(group_centres)
ax3.set_xticklabels(list(map(str, learning_rates)))
ax3.set_xlabel("Learning Rate")

# Final layout, then display.
fig.tight_layout()
plt.show()