種型の練習

ousttrue2009-01-27

球の法線計算まで実装することができた。
前回つくりかけのベクターは当初想定していた程度のものは完成。
早速組み込んで使ってみた。
正規化忘れをコンパイル時に検出できるのは悪くない感じ。
C++のエクスプレッションテンプレート(よく知らないw)のような
感じで最適化に役立ちそうな気もするがどうなんだろう。


haskellの型システムだと、インターフェースを定義するのがclass typeで
メンバを定義するのがdata typeという感じになると理解。
実装はclassタイプを継承してinstanceで。
もうひとつ型の型を規定するseed(種)というのがあったので
今回はこれを作ってみた。
種は、文法的に特別な識別子があるわけでなく文脈からそうなるというものらしく
辻褄を合わせるのが難航。
normalizeをメンバからはずすことで妥協した。
型変数を使ったGenericな定義を作るには慣れが必要らしい。


あとでモナドをマスターしたら華麗に作り直す予定。
単位ベクトルのデータコンストラクタをエクスポートしないという細工がしてある。

module Vector(Vec, sMul, sDiv, dot, sqareNorm, cross, 
  normalize, unnormalize,
  Vector(Vector), UnitVector) where 

--------------------------------------------------------------------------------
-- class type Vec
--------------------------------------------------------------------------------
class Vec v where
  sMul, sDiv ::(Floating f)=>v f->f->v f
  dot ::(Floating f)=>v f->v f->f
  sqareNorm, norm ::(Floating f)=>v f->f
  cross ::(Floating f)=>v f->v f->v f
  -- default implementation
  sDiv v x=sMul v (1/x)
  sqareNorm v=dot v v
  norm v=sqrt $ sqareNorm v

--------------------------------------------------------------------------------
-- data type Vector
--------------------------------------------------------------------------------
data (Floating f)=>Vector f=Vector [f]

instance (Floating f)=>Show (Vector f) where
  show (Vector vector)=show vector

instance (Floating f)=>Eq (Vector f) where
  (==) (Vector lhs) (Vector rhs)=all (\ (l, r)->l==r) $ zip lhs rhs

instance (Floating f)=>Num (Vector f) where
  (+) (Vector lhs) (Vector rhs)=Vector $ zipWith (+) lhs rhs
  (-) (Vector lhs) (Vector rhs)=Vector $ zipWith (-) lhs rhs
  (*) _ _=error "not defined"
  abs _ =error "not defined"
  signum _ =error "not defined"
  fromInteger _ =error "not defined"

instance Vec Vector where
  sMul (Vector vector) scalar=Vector $ map ((*) scalar) vector
  dot (Vector lhs) (Vector rhs)=foldr (+) 0 [x*y|(x,y)<-zip lhs rhs]
  cross (Vector [lx, ly, lz]) (Vector [rx, ry, rz])=
    Vector $ [ly*rz-lz*ry, lz*rx-lx*rz, lx*ry-ly*rx]
  cross _ _=error "not defined"

--------------------------------------------------------------------------------
-- data type UnitVector
--------------------------------------------------------------------------------
-- not export
data (Floating f)=>UnitVector f=UnitVector [f]

instance (Floating f)=>Show (UnitVector f) where
  show (UnitVector vector)=show vector

instance (Floating f)=>Eq (UnitVector f) where
  (==) (UnitVector lhs) (UnitVector rhs)=all (\ (l, r)->l==r) $ zip lhs rhs

instance (Floating f)=>Num (UnitVector f) where
  (+)  _ _=error "not defined"
  (-) _ _=error "not defined"
  (*) _ _=error "not defined"
  abs _ =error "not defined"
  signum _ =error "not defined"
  fromInteger _ =error "not defined"

instance Vec UnitVector where
  sMul (UnitVector vector) scalar=UnitVector $ map ((*) scalar) vector
  dot (UnitVector lhs) (UnitVector rhs)=foldr (+) 0 [x*y|(x,y)<-zip lhs rhs]
  cross (UnitVector [lx, ly, lz]) (UnitVector [rx, ry, rz])=
    UnitVector $ [ly*rz-lz*ry, lz*rx-lx*rz, lx*ry-ly*rx]
  cross _ _=error "not defined"
  -- must be normalized
  sqareNorm v=1

--------------------------------------------------------------------------------
-- convert
--------------------------------------------------------------------------------
normalize::(Floating f)=>Vector f->UnitVector f
normalize vector=UnitVector normalized
  where
    (Vector normalized)=sDiv vector $ norm vector

unnormalize::(Floating f)=>UnitVector f->Vector f
unnormalize (UnitVector vector)=Vector vector
module VectorTest where
import Test.HUnit
import Vector

main::IO Counts
main=do
  runTestTT $ TestList [
    (TestCase $ assertEqual "Vector==Vector" (Vector [0, 0, 0]) 
      (Vector [0, 0, 0])),

    (TestCase $ assertEqual "Vector+Vector" (Vector [4, 4, 4])
      ((Vector [1, 2, 3])+(Vector [3, 2, 1]))),

    (TestCase $ assertEqual "Vector-Vector" (Vector [-1, -2, -3])
      ((Vector [0, 0, 0])-(Vector [1, 2, 3]))),

    (TestCase $ assertEqual "Vector*Scalar" (Vector [2, 4, 6])
      ((Vector [1, 2, 3]) `sMul` 2)),

    (TestCase $ assertEqual "Vector/Scalar" (Vector [1, 2, 3])
      ((Vector [2, 4, 6]) `sDiv` 2)),

    (TestCase $ assertEqual "Vector dot Vector" 14
      (dot (Vector [1, 2, 3]) (Vector[1, 2, 3]))),

    (TestCase $ assertEqual "Vector cross Vector" (Vector [0, 0, 1])
      (cross (Vector [1, 0, 0]) (Vector[0, 1, 0]))),

    (TestCase $ assertEqual "normalize Vector" (Vector [1, 0, 0])
      (unnormalize $ normalize $ Vector [3, 0, 0]))
    ]