私、foldの名に恥じない立派な関数になるから!!

昨日のは、ただ単に定数を計算してるだけで、foldと名乗っていいものかどうかアレだったので、定数畳み込みについて考える。
とりあえず、javacが糞だと言うためだけに、以下のコードをコンパイルする。

class Fold {
	public static int f( int x ) {
		return x + 1 + 2;
	}
}

このコードはどういうバイトコードになるかご存知だろうか。(このネタは前に書いたような気がするけど…)

public static int f(int);
  Code:
   0:	iload_0
   1:	iconst_1
   2:	iadd
   3:	iconst_2
   4:	iadd
   5:	ireturn

なんと、驚くべきことに、わざわざ定数1と2をロードして加算してるのだ。なんてこった!この糞がッ!

def f( x: int ) -> int:
	return x + 1 + 2
0:	LOAD_CONST	2, 
2:	LOAD_ARG	0, 
4:	LOAD_CONST	1, 
6:	ADD_I	
7:	ADD_I	

まあ、a24zcも人のこと言えないんだけど。というか、こっちのほうがスタック消費が一個多いのでむしろ悪い。(あと、javacの後ろには強力なJITがあるわけで、もっとよくない)


と、いうのはいいとして、

class Fold {
	public static int f( int x ) {
		return 1 + 2 + x;
	}
}

これだと、

public static int f(int);
  Code:
   0:	iconst_3
   1:	iload_0
   2:	iadd
   3:	ireturn

ちゃんとこうなるんだけど。


これは、構文解析の方法が問題なのだ。加算は左結合なので、

x + 1 + 2 = (x+1) + 2
1 + 2 + x = (1+2) + x

こういうふうに構文解析される。こいつを簡単に定数畳み込みやろうとすると、

(x+1) + 2  -> ……? (x+1) は定数じゃないのでどうしようもない
(1+2) + x  -> 3 + x  (1+2) は定数なので、3にしてしまう

こうなってしまうわけだ。


そこで、構文解析の方法をちょっと以下のようにしてみた。もっとうまい方法はあるかもしれんが。


まず、算術式はただの木にしないで、全部、ax+bの形にしておく

3 -> (0*0 + 3)           ;; 定数式は a=0 の b=cとする
x -> (x*1 + 0)           ;; 変数は a=1 x=変数 b=0とする。
func() -> (func()*1 + 0) ;; 関数呼び出しも一緒

こんな感じ。あとはひたすらパターンで。

x + 1 + 2 -> 
  ((x*1 + 0) + (0*0 + 1)) + 2 ->  ;;片方の係数が0だったら、定数項を足すだけ
     (x*1 + 1) + (0*0 + 2) -> 
       (x*1 + 3) -> x + 3  ;; 左結合でもできる。

係数が同じ時は足せる。

(x+3) + (y+5) ->
  (x*1 + 3) + (y*1 + 5) ->
    ((x+y)*1 + 3+5) ->
      (x+y+8)

あー。項が同じときは足せるな…memo。

x*3 + x*2 ->
  ((x*3 + 0) + (x*2 + 0)) ->
     x*3 + x*2
;;; memo: x*5 にする。
let add l r =
  match swap_if_zero l r with
      ( FloatExpr(_,0.0,al), (FloatExpr(er,mr,ar)) ) -> FloatExpr( er, mr, ar+.al )
    | ( FloatExpr(el,ml,al), (FloatExpr(er,mr,ar)) ) ->
	if ml = mr then
	  (FloatExpr ( genadd el er,ml,ar+.al) )
	else
	  (FloatExpr (genadd (genmul el (constf ml)) (genmul er (constf mr)),
		      1.0,
		      ar+.al ) )
    | ( IntExpr(el,0,al), ( IntExpr(er,mr,ar)) ) -> IntExpr( er, mr, ar+al )
    | ( IntExpr(el,ml,al), (IntExpr(er,mr,ar)) ) ->
	if ml = mr then
	  (IntExpr ( genadd el er,ml,ar+al) )
	else
	  (IntExpr ( genadd (genmul el (consti ml)) (genmul er (consti mr)),
		     1,
		     ar+al ) )

    | ( FloatExpr(el,0.0,al), VarExpr(er) ) -> FloatExpr( er, 1.0, al )
    | ( IntExpr(el,0,al), VarExpr(er) ) -> IntExpr( er, 1, al )

    | ( FloatExpr(el,1.0,al), VarExpr(ev) ) -> FloatExpr( genadd el ev, 1.0, al )
    | ( IntExpr(el,1,al), VarExpr(ev) ) -> IntExpr( genadd el ev, 1, al )

    | ( FloatExpr(el,mul,al), VarExpr(ev) ) -> FloatExpr( genadd (genmul el (constf mul)) ev,
							  1.0, al )
    | ( IntExpr(el,mul,al), VarExpr(ev) ) -> IntExpr( genadd (genmul el (consti mul)) ev,
						      1, al )

    | ( VarExpr(el), VarExpr(er) ) -> VarExpr( genadd el er )
    | ( VarExpr _, _ ) -> raise NulExprException
    | ( IntExpr(_,_,_), ( FloatExpr(_,_,_) )) -> raise TypeUnmatch
    | ( FloatExpr(_,_,_), ( IntExpr(_,_,_) )) -> raise TypeUnmatch

なんか、もうちょいやりようがあると思うけど…大体こんな感じ。

8+4+x -> x+12
x+8+4 -> x+12

このぐらいまではできた。減算をやるには、

a-b = a+(-b)
    = (a*1 + 0) + (b*-1 + 0)

これでいける…はず。あー、そうか。
こうする利点は、左右を反転できること。

a-b != b-a
a+(-b) == (-b)+a

うーん。なんかただのメモっぽくなってしまった…まあいいか。
とりあえずこのペースでいくと、構文解析が終わるまでに飽きそうな気が…やりすぎないようにしないと。