« tcpdfで日本語フォント | メイン | numpy onehotに変換 »

2020年06月19日

tensorfklow2カスタマイズング

https://www.atmarkit.co.jp/ait/articles/2003/23/news024.html

基本パターン
 TensorFlow 2.xでは、

活性化関数: tf.keras.layers.Layerクラスをサブクラス化
レイヤー: tf.keras.layers.Layerクラスをサブクラス化
オプティマイザ(最適化アルゴリズム): tf.keras.optimizers.Optimizerクラスをサブクラス化
損失関数: tf.keras.losses.Lossクラスをサブクラス化
評価関数: tf.keras.metrics.Metricクラスをサブクラス化
などの機能がカスタマイズ可能である


各損失関数の中身は ここ Numpy backend にあった
----------------------------------
# カスタムの損失関数のPython関数を実装
def custom_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))

# カスタムの損失関数クラスを実装(レイヤーのサブクラス化)
# (tf.keras.losses.MeanSquaredError()の代わり)
class CustomLoss(tf.keras.losses.Loss):
  def __init__(self, name="custom_loss", **kwargs):
    super(CustomLoss, self).__init__(name=name, **kwargs)

  def call(self, y_true, y_pred):
    y_pred = tf.convert_to_tensor(y_pred)  # 念のためTensor化
    y_true = tf.cast(y_true, y_pred.dtype) # 念のため同じデータ型化
    return custom_loss(y_true, y_pred)
----------------------------------


----------------------------------
# カスタムの活性化関数のPython関数を実装
def custom_activation(x):
    return (tf.exp(x)-tf.exp(-x))/(tf.exp(x)+tf.exp(-x))

# カスタムの活性化関数クラスを実装(レイヤーのサブクラス化)
# (tf.keras.layers.Activation()の代わり)
class CustomActivation(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super(CustomActivation, self).__init__(**kwargs)

  def call(self, inputs):
    return custom_activation(inputs)
-------------------------------------

-------------------------------------
# 正解かどうかを判定する関数を実装
def custom_matches(y_true, y_pred):         # y_trueは正解、y_predは予測(出力)
  threshold = tf.cast(0.0, y_pred.dtype)              # -1か1かを分ける閾値を作成
  y_pred = tf.cast(y_pred >= threshold, y_pred.dtype) # 閾値未満で0、以上で1に変換
  # 2倍して-1.0することで、0/1を-1.0/1.0にスケール変換して正解率を計算
  return tf.equal(y_true, y_pred * 2 - 1.0) # 正解かどうかのデータ(平均はしていない)

# カスタムの評価関数クラスを実装(サブクラス化)
# (tf.keras.metrics.BinaryAccuracy()の代わり)
class CustomAccuracy(tf.keras.metrics.Mean):
  def __init__(self, name='custom_accuracy', dtype=None):
    super(CustomAccuracy, self).__init__(name, dtype)

  # 正解率の状態を更新する際に呼び出される関数をカスタマイズ
  def update_state(self, y_true, y_pred, sample_weight=None):
    matches = custom_matches(y_true, y_pred)
    return super(CustomAccuracy, self).update_state(
        matches, sample_weight=sample_weight) # ※平均は内部で自動的に取ってくれる
----------------------------------


投稿者 muuming : 2020年06月19日 05:35