CUDAで連続ウェーブレット変換

計算時間のかかる連続ウェーブレット変換をCUDAで実装してみます (CUDAのSDKに離散ウェーブレット変換のサンプルあり).

連続ウェーブレット変換

ウェーブレット変換(wavelet transform)はマザーウェーブレットと呼ばれる基本参照波を拡大縮小,平行移動したものを用いて, 元のデータf(t)を時間(もしくは位置)と周波数(スケーリング)に関する成分W(a,b)に変換します.

eq_wt.jpg

マザーウェーブレット(もしくはウェーブレット関数)の代表的なものには,

  • Haar wavelet
    eq_haar.jpg
  • Mexican hat wavelet
    eq_mexicanhat.jpg
    などがあります.

CUDAでの実装

1次元の信号f(t)を入力として,そのウェーブレット変換W(a,b)を計算します. 入力データfは離散化されたデータf_iとして,配列に格納されていることとします. また,今回は実数データのみで検証しますが,実装は虚数値にも対応しています. ウェーブレット関数にはMexican hatを用います.

メモリ確保とデータ転送

ホストメモリ上の入力データをデバイスメモリに転送します.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
    float *dSim, *dSre, *dWre, *dWim;
    int size;
 
    // デバイスメモリの確保とホストからの転送
    size = ntrans*sizeof(float);
    cutilSafeCall(cudaMalloc((void**)&dSre, size));
    cutilSafeCall(cudaMalloc((void**)&dSim, size));
    cutilSafeCall(cudaMemcpy((void*)dSre, (void*)hFi, size, cudaMemcpyHostToDevice));
    cutilSafeCall(cudaMemset((void*)dSim, 0, size));
 
    size = ntrans*nscale*sizeof(float);
    cutilSafeCall(cudaMalloc((void**)&dWre, size));
    cutilSafeCall(cudaMalloc((void**)&dWim, size));
    cutilSafeCall(cudaMemset((void*)dWre, 0, size));
    cutilSafeCall(cudaMemset((void*)dWim, 0, size));

入力データのサイズをntrans, 周波数方向の変換解像度をnscaleとしています. dSreとdSimがそれぞれ入力データの実部と虚部,dWreとdWimには変換後の値が入ります.

カーネル

変換カーネルは,

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
__device__
float MexicanHat(float t)
{
    t = t*t;
    return MEXICAN_HAT_C*(1.0-t)*exp(-t/2.0);
}
__device__
float MexicanHatIm(float t)
{
    return 0.0f;
}
 
__global__
void cwt(float *src_re, float *src_im, float *wt_re, float *wt_im, 
         int nt, int ns, float t0, float s0, float dt, float ds, int res)
{
    int i = blockIdx.x*blockDim.x+threadIdx.x;
    int j = blockIdx.y*blockDim.y+threadIdx.y;
 
    if(i < nt && j < ns){
        float s = (s0+j*ds);    // スケーリング(周波数)
        if(s == 0.0f) s = 1e-10;
        float t = (t0+i*dt);    // 平行移動(位置)
 
        float wr = 0.0;
        float wi = 0.0;
 
        float kstep = 1.0/res;
        float k;
        for(k = 0.0; k < nt; k += kstep){
            float T = (k*dt-t)/s;
 
            wr += src_re[(uint)k]*MexicanHat(T)+src_im[(uint)k]*MexicanHatIm(T);
            wi += src_re[(uint)k]*MexicanHat(T)-src_im[(uint)k]*MexicanHatIm(T);
        }
 
        wt_re[j*nt+i] = wr/(sqrt(s)*ivalp);
        wt_im[j*nt+i] = wi/(sqrt(s)*ivalp);
    }
}

引数として,入力データ,出力先,時間(位置)方向と周波数方向の解像度,初期値,ステップ幅を指定しています. resはウェーブレット関数合成時の分解数です. そのまんま実装したので,シェアードメモリやテクスチャメモリなどを使えばもっと速くなりそうですが...

カーネル呼び出し側は,

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
    dim3 block(BLOCK_SIZE, BLOCK_SIZE);
    dim3 grid((ntrans+block.x-1)/block.x, (nscale+block.y-1)/block.y);
 
    cwt<<< grid, block >>>(dSre, dSim, dWre, dWim, ntrans, nscale, x0, y0, dx, dy, 4);
 
    // カーネル実行エラーのチェック
    cutilCheckMsg("Kernel execution failed");
 
    // GPUスレッドが終わるのを待つ
    cutilSafeCall(cudaThreadSynchronize());

ホストへの結果の転送と後片付け

結果をホストに転送し,確保したメモリを解放します.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
    // デバイスからホストへ結果を転送
    size = ntrans*nscale*sizeof(float);
    cutilSafeCall(cudaMemcpy(hWre, dWre, size, cudaMemcpyDeviceToHost));
    cutilSafeCall(cudaMemcpy(hWim, dWim, size, cudaMemcpyDeviceToHost));
 
    // デバイスメモリ解放
    cutilSafeCall(cudaFree(dSre));
    cutilSafeCall(cudaFree(dSim));
    cutilSafeCall(cudaFree(dWre));
    cutilSafeCall(cudaFree(dWim));

結果

CPU用にも実装し(OpenMPで2スレッド並列で実行),計算時間を比較します. CPU側の実装は計算の最適化として,ウェーブレット関数値の前計算,畳み込み時の範囲指定などを行っています. 入力データとして,周波数が異なるsin波の合成波を用い, データ数は時間方向に256,512,1024,周波数方向に128で計測しました.

2565121024
CPU102300694
GPU111842

単位はミリ秒です.デバイスメモリの確保,転送を含めるとGPUだと100msぐらいになってしまいます. 気になったので調べてみると,最初のcudaMallocで40-80msぐらいかかっています. その後のcudaMemcpyやcudaMallocは1msぐらいなのでだいぶ遅いです.これは, CUDA Programming Guideの3.2章の最初の方に

There is no explicit initialization function for the runtime; it initializes the first time a runtime function is called.

とあるように,最初にランタイム関数が呼ばれたときに初期化が行われるためだと思われます. 今回は1回のみの計測でしたが,何回も計測して平均を取るときは注意した方がよいかも知れません.

変換結果を以下に示します.

wavelet_trans.jpg

添付ファイル: fileeq_wt.jpg 641件 [詳細] fileeq_mexicanhat.jpg 575件 [詳細] filewavelet_trans.jpg 679件 [詳細] fileeq_haar.jpg 633件 [詳細]

トップ   編集 凍結 差分 バックアップ 添付 複製 名前変更 リロード   新規 一覧 単語検索 最終更新   ヘルプ   最終更新のRSS
Last-modified: 2011-10-27 (木) 15:09:13 (3251d)