お疲れ様です。Leapmindの脇坂です。
この記事は、TensorFlow Advent Calendar 2017 17日目の記事です
今回は、tensorflow の `tensorflow.python.framework.function.Defun` を使って、weight の 3値化をしてみます。
Defunとは
`tensorflow.python.framework.function.Defun` はtensorflowのfunctionを定義するためのデコレータです。
今現在(tensorflow v1.4)ドキュメントにはのってない機能です。`tf.contrib.eager.defun` とは違うので注意
引数にtensorflow DataTypeをとって、任意のpython functionをデコレートし、tensorflowのgraphでそのfunctionが使えるように定義します。内部的には、たぶん、graphにfunction.protoを追加します。
tensorflowにおけるfucntionは1つ以上のoperationをまとめたものだそうです。
参考: https://www.tensorflow.org/extend/language_bindings#overview
コード例は、こんな感じです。
MyFuncの定義
@tf.Defun(tf.float32, tf.float32)
def MyFunc(x, y):
return x + y, x - y
グラフを作る時に、MyFuncを実行。
a = tf.Constant([1.0])
b = tf.Constant([2.0])
c, d = MyFunc(a, b, name='mycall')
他にオプションのキーワード引数として、
* `func_name`: Functionの名前
* `grad_func` or `python_grad_func`: backward関数
* `out_names`: output tensor の名前
* `shape_func`: output tensor のshapeを指定
が、あります。
さらに詳しく知りたい方は、ソースコードとtestを見てみてください。
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/function.py
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/function_test.py
Defunをどんな時に使いますか
さて、この`Defun`がどういったモノかはなんとなくわかりましたが、どんな時に効果的に使えるのでしょうか?
1. メモリ・計算量を減らす
2. Backwardの近似 (backwardを別の関数で近似)
この二つが考えれます。
メモリ・計算量を減らす
この使用法が一般的になるのだと思いますが、すみません、ちゃんと検証していないので自信がありません。
tensorflowでの使用例を見ると、`swish` の実装で、operationをまとめてmemory消費を少なくするために使われています。
`swish` は2、3ヶ月前に発表された(https://arxiv.org/abs/1710.05941) activation functionで、式はすごく単純
\begin{aligned}
f(x) = x \times sigmoid(x)
\end{aligned}
です。
普通にtensorflowで実装するとこうなります。
output = tf.sigmoid(x) * x
さて、deeplearningのフレームワークでは、勾配計算のためにBackpropagationを使い、そのためforward時に各operationの入力をメモリに残しておきます。
上記のように実装すると、`*` operationのために、`x` と `tf.sigmoid(x)` の両方をメモリに残しておきます。が、これは無駄です。
メモリの観点だけで見れば、 `x` だけ残しておけば良いです。
そこで、tensorflowの使用例では、 `noinline=True` を使い、下記のような実装をしています。
@function.Defun(
grad_func=_swish_grad,
shape_func=_swish_shape,
func_name="swish",
noinline=True)
def swish(x):
return x * math_ops.sigmoid(x)
注意: noinlineで本当にメモリ使用が抑えられるのか、試しておりません。。。XLAを有効にした時だけの話かも?
Backwardの近似
こちらの使用方法はレアケースになるかと思いますが、本記事のメインになります。
やっとです。
ここでニューラルネットのパラメータやアクティベーションの量子化について、さらっと振り返っておきます。
通常32bitや16bit floatのパラメータ(weight)を、メモリ削減のため2値化、3値化(ternary)、N bit化などに量子化する際、
推論時には2値や3値を使い、学習時にはパラメータの値を実数で保持しておいて実数の値を更新すると、高い精度が得られる事が知られています。
この際、forward時には実数のパラメータから量子化を行いますが、量子化を行う関数はstep関数なので、普通に勾配を求めるとうまく学習できないため、
backward時には量子化を行う関数を近似して、勾配を求めます。
参考: https://developer.smartnews.com/blog/2017/03/neural-network-quantization/
さて、tensorflow の python apiで、このようなbackward時の近似を行う際、`Graph.gradient_override_map()` や `tf.stop_gradient()` を使う方法もあるのですが、非常に読みにくコードになってしまいます。
ここで `Defun` を使ってforwardとbackwardを分けて実装すると、綺麗に書く事ができます。
パラメータの3値化
では、パラメータ(weight)の3値化(ternary)の実装例を、順をおってみていきます。
3値化の定義
まず、3値化の定義です。この記事ではわかりやすくするために適当に以下のようにします。
* スケールファクタは、パラメータの絶対値の平均
* パラメータが閾値0.5より大きい場合 +スケールファクタ
* パラメータが閾値-0.5より小さい場合に -スケールファクタ
* パラメータが0.5以下-0.5以上の場合、0
\begin{aligned}
f(x) = \left\{ \begin{array}{ll}
+ mean(|x|) & (x>0.5) \\
- mean(|x|) & (x<-0.5) \\
0 & (otherwise)
\end{array} \right.
\end{aligned}
Backwardの近似なし3値化関数
では、次に、
この3値化関数をBackwardの近似なしで、実装してみます。
def no_backward_approximate_ternary_function(weights):
"""Forward
Args:
weights(tf.Variable): The weights to be ternary.
Returns:
ternary_weights(tf.Variable): The ternary valued weights.
"""
threshold = 0.5
mask_positive = (weights > threshold)
mask_negative = (weights < -threshold)
scaling_factor = tf.reduce_mean(tf.abs(weights))
positive_weights = scaling_factor * tf.where(mask_positive, tf.ones_like(weights), tf.zeros_like(weights))
negative_weights = - scaling_factor * tf.where(mask_negative, tf.ones_like(weights), tf.zeros_like(weights))
ternary_weights = positive_weights + negative_weights
return ternary_weights
forward時のこの関数をplotしてみます。
tf.InteractiveSession()
weights_size = (100,)
np_weights = np.random.uniform(-2., 2., size=weights_size).astype(np.float32)
weights = tf.convert_to_tensor(np_weights)
no_backward_approximate_ternary_weights = no_backward_approximate_ternary_function(weights)
order = np.argsort(np_weights)
plt.plot(np_weights[order], no_backward_approximate_ternary_weights.eval()[order])
plt.ylabel("ternary weights")
plt.xlabel("input weights")
plt.title("ternary func")
plt.show()
forward時はこれで良いのですが、backward時の勾配をplotすると、
(グラフの横軸が3値化したパラメーターに関数する勾配、縦軸が実数のパラメータに関する勾配です。)
np_grad_tenary_weights = np.random.uniform(-10., 10., size=weights_size).astype(np.float32)
grad_tenary_weights = tf.convert_to_tensor(np_grad_tenary_weights)
gard_no_backward_approximate_weights, = tf.gradients(no_backward_approximate_ternary_weights, weights, grad_ys=grad_y)
order = np.argsort(np_grad_tenary_weights)
plt.plot(np_grad_tenary_weights[order], gard_no_backward_approximate_weights.eval()[order])
plt.ylabel("grad weights")
plt.xlabel("grad ternary weights")
plt.title("backward ternary func")
plt.show()
このように無茶苦茶になってしまいます。
Defunを使ったBackwardの近似ありの3値化関数
では、Defunを使ってBackwardを近似してみます。
3値関数が恒等写像であると仮定して、その上でbackwardを実装します。
このように簡潔に書けます。
@Defun(tf.float32, tf.float32)
def ternary_backward(weights, grad_ternary):
"""Backward
Args:
grad_ternary(tf.Tensor): The gradient w.r.t ternary valued weights.
Return:
grad_weights(tf.Tensor): The gradient w.r.t. normal (non-ternary) weights.
"""
grad_weights = grad_ternary
return grad_weights
@Defun(tf.float32,
grad_func=ternary_backward,
func_name="TernaryWeight",
shape_func=lambda op: [op.inputs[0].get_shape()])
def ternary_function(weights):
"""Forward
Args:
weights(tf.Variable): The weights to be ternary.
Returns:
ternary_weights(tf.Variable): The ternary valued weights.
"""
threshold = 0.5
mask_positive = (weights > threshold)
mask_negative = (weights < -threshold)
scaling_factor = tf.reduce_mean(tf.abs(weights))
positive_weights = scaling_factor * tf.where(mask_positive, tf.ones_like(weights), tf.zeros_like(weights))
negative_weights = - scaling_factor * tf.where(mask_negative, tf.ones_like(weights), tf.zeros_like(weights))
ternary_weights = positive_weights + negative_weights
return ternary_weights
3値化関数に恒等写像を仮定したので、
backwardは、3値化したパラメーターに関数する勾配と、実数のパラメータに関する勾配とが同じになります。
では、backward時の勾配をplotすると、
order = np.argsort(np_grad_tenary_weights)
plt.plot(np_grad_tenary_weights[order], gard_weights.eval()[order])
plt.ylabel("grad weights")
plt.xlabel("grad ternary weights")
plt.title("backward ternary func")
plt.show()
わーい、意図した通りになりましたー!!
これでおしまいです。
Back to Index