RNNの評価

RNNの評価(混同行列の出力・分析)

次のステップとして RNNの評価(混同行列の出力・分析) に進みます。
RNNモデルの性能を詳細に分析し、強み・弱みを理解することで、CNNやTransformerとの比較がスムーズになります。


RNNの評価: 混同行列の出力・分析

追加する内容

  • 混同行列(Confusion Matrix)の計算
  • Accuracy, Precision, Recall, F1-score の算出
  • クラスごとの分類性能の可視化

1. 必要なライブラリのインポート

import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

2. モデルとテストデータのロード

# 保存したモデルをロード
model_path = os.path.join(DATA_DIR, "rnn_model.keras")
model = tf.keras.models.load_model(model_path)

# テストデータの再ロード
X_test = np.load(os.path.join(DATA_DIR, "test_features.npy"))
y_test = np.load(os.path.join(DATA_DIR, "test_labels.npy"))

# RNNの入力形状に変換
X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))

3. 予測結果の取得

# モデルの予測
y_pred_prob = model.predict(X_test)
y_pred = np.argmax(y_pred_prob, axis=1)

4. 混同行列の計算と可視化

# 混同行列の計算
conf_matrix = confusion_matrix(y_test, y_pred)

# 混同行列の可視化
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=np.unique(y_test), yticklabels=np.unique(y_test))
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix - RNN Model")
plt.show()

5. 精度評価(Accuracy, Precision, Recall, F1-score)

# クラスごとの詳細レポート
print("\nClassification Report:\n")
print(classification_report(y_test, y_pred, digits=4))

まとめ

このステップでは、RNNの評価 を行い、以下を確認しました。

  • 混同行列の可視化 により、分類ミスのパターンを分析
  • Accuracy, Precision, Recall, F1-score を算出し、モデルの性能を評価

この分析を基に、次の CNNモデルの実装 に進みます。×RNNの結果が良くないので改善中