こんにちは
GoogleColab上でAutoGluonを使うための備忘録です。
今回はTabularPredictorによる二値分類を行いますが、Taskを変えてあげれば画像分類なども可能です。
詳しくは以下ドキュメント
AutoGluon ドキュメント
AutoGluon Tasks — AutoGluon Documentation 0.3.1 documentation
準備
trainデータ等があるGoogleDriveのマウント
# ドライブのマウント from google.colab import drive drive.mount('/content/drive') basePath = "/content/drive/MyDrive/BERT/"
AutoGluonのインストール
!pip uninstall -y mkl
!pip install --upgrade mxnet-cu100
!pip install autogluon
!pip install -U ipykernel
# and restart runtime
Import
import autogluon as ag from autogluon.tabular import TabularPredictor
色々なサイトを見てるとTabularPredictionと記述していたりするけど、AutoGluon最新版では名前が変更された?
データの読み込み
ドロップさせるデータ列がない場合は.dropは必要なし
train_data = TabularPredictor.Dataset('/content/drive/MyDrive/train.csv').drop(labels=["カラム名1","カラム名2"],axis=1) test_data = TabularPredictor.Dataset('/content/drive/MyDrive/test.csv').drop(labels=["カラム名1","カラム名2"],axis=1) train_data.head() # 確認
訓練
label_column="正解ラベル" predictor = TabularPredictor(label=label_column, path="./AutoGluonModel", eval_metric="recall").fit(train_data, presets="best_quality", ag_args_fit={'num_gpus': 1})
特筆すべき引数
- eval_metric
モデルの評価に用いる指標
今回私の場合は正解データが極端に少ない学習をしたので、Recallを評価指標とした。
options
['accuracy'、 'balanced_accuracy'、 'f1'、 'f1_macro'、 'f1_micro'、 'f1_weighted'、 'roc_auc'、 'roc_auc_ovo_macro'、 'average_precision'、 'precision'、 'precision_macro'、 'precision_micro'、 ' Precision_weighted '、' recall '、' recall_macro '、' recall_micro '、' recall_weighted '、' log_loss '、' pac_score '] - presets
best_qualityを設定することで、より高精度なモデルの学習となる。
その分もちろん時間はかかる - ag_args_fit
GPUを使う場合は{'num_gpus': 1}を指定
推論
# 学習済みモデルのロード predictor = TabularPredictor.load("./AutoGluon") # unnecessary, just demonstrates how to load previously-trained predictor from file # 推論用データから正解ラベルをドロップ y_test = test_data[label_column] test_data_nolab = test_data.drop(labels=label_column,axis=1) train_data.head() # 推論 y_pred = predictor.predict(test_data_nolab) # 推論結果 print("Predictions: \n", y_pred)
その他学習の確認など
results = predictor.fit_summary()
各モデルの学習情報が列挙される。
推論自体はこの中でもっともスコアがいいものが使用されている。
以上