はじめMath! Javaでコンピュータ数学

第63回統計の数学 相関係数とは[後編]

前回は相関係数の計算方法を紹介しました。今回は、相関係数の計算手順をJava言語のプログラムとして書き表すことに挑戦してみましょう。

問題 CSVファイルのデータを読み込み、相関係数を計算するプログラムを作りましょう。

第61回の問題で作成したコードに、相関係数を計算するコードを加えましょう。それだけでは退屈ですから、コマンドライン引数でデータファイルを指定し、指定されたファイルを読み込んで処理するプログラムに変更してください。なに、ほんのちょっとした変更だけで済みます。⁠Java コマンドライン 引数⁠としてGoogleで検索すれば、たくさんのサンプルが表示されるでしょう。それらから盗み取って作成してください。

解説

問題 CSVファイルのデータを読み込み、相関係数を計算するプログラムを作りましょう。

計算に用いたデータファイルdata001.csvも前回と同じものを用いましょう。

データファイル data001.csv

  • 100,120
  • 200,195
  • 278,280

それだけでは面白くありませんから、負の相関を示しそうな次のデータでも試してください。

データファイル data002.csv

  • 20,240
  • 150,225
  • 278,120

コンパイル時には、-Xlint:uncheckedというオプションを忘れずに付けてください。忘れないように、次のようなバッチファイルを作成していちいち入力する手間を省くと良いでしょう。

バッチファイルcomp.bat
del *.class
javac Sample_Correlation.java -Xlint:unchecked
java Sample_Correlation ./data001.csv

今回の課題のために追加されたメソッドは次の通りです。

getCorrelationCoefficient
getAverage
getVariance

主要な追加コードはこれだけなのですが、クラス名の変更等で広い範囲に変更部分が散らばっているため、コード全体を次ページに示します。

ソースコード:Sample Correlation.java
/*
 * filename : Sample_Correlation.java
 *            CVS ファイルを読み込んで、回帰直線の
 *            定数値を求め、グラフ化し、相関係数を求めます。
 * compile  : c:\>javac Sample_RegLine.java -Xlint:unchecked
 * usage    : c:\>java Sample_RegLine ./data002.csv
 */

import java.io.*;
import java.util.*;
import java.awt.*;


public class Sample_Correlation extends Frame {

  //データファイル名
  static String DATAFILE = "./data001.csv";

  //表示するウインドウの最大・最小座標値
  static int SCREEN_MAX_X = 300;
  static int SCREEN_MIN_X = 0;
  static int SCREEN_MAX_Y = 300;
  static int SCREEN_MIN_Y = 0;
  //
  //グラフ中にデータをプロットする丸のサイズ
  static int DATA_PLOT_OVAL_WIDTH = 4;
  static int DATA_PLOT_OVAL_HEIGHT = 4;
  //図形を描画するオブジェクト
  MyCanvas mc;
  /**
   * メインメソッド
   * 簡略のためすべてここから呼び出し
   */
  public static void main (String args []) {

    //線形最小二乗法によって得られた定数
    //y=ax+b のa,b を格納する。val[0]=a,val[1]=b
    //基本型は参照渡しできないのでこうしてみる
    double val[] = new double[2];
    //ファイルから読み込んだデータの一時保管用
    Vector v1 = new Vector();
    //ファイルから読み込んだ元データは文字列であるため、
    //数値に変換したものをこちらに格納する。
    Vector DataA = new Vector();


    //データファイル読み込み
    try{
      if (args.length ==0 ){
        readTextFromFile_AndSetVector(DATAFILE,v1);
      } else {
        readTextFromFile_AndSetVector(args[0],v1);
      }
    }
    catch(Exception e){
      System.out.println(e.toString());
      System.exit(-1);
    }// of try catch

    //一時読み込みしたデータは文字列なので数値に変換し、
    //配列にセットする
      KataHenkan(v1,DataA);
    //得られた数値データに最小二乗法を適用し、近似
    //直線の定数を得る
      SenkeiSaishouJijyouHou(DataA,val);

    //相関係数を求める
      double r = getCorrelationCoefficient(DataA);

    //標準出力に計算結果を出力
    System.out.println("y=ax+b");
    System.out.println("value of a = "+val[0]);
    System.out.println("value of b = "+val[1]);
    //相関係数を出力
    System.out.println("r = " + r);

    //ウインドウを作成し、結果を表示
    new Sample_Correlation(DataA,val).show();

  }// end of main()


  /*
   * 目的  : 相関係数を求める
   * 引数  : data 数値データの配列への参照
   * 戻り値 : 相関計数値
   */
  static double getCorrelationCoefficient(Vector data){
    //相関係数を求めるために用意する一時的な変数
    double XAve = 0; //観測値のx 成分の平均値
    double YAve = 0; //観測値のy 成分の平均値
    double XVari = 0; //x の分散
    double YVari = 0; //y の分散
    double XYVari = 0; //xy の共分散

    XAve = getAverage(data,"x");
    YAve = getAverage(data,"y");

    XVari = getVariance(data,"x",XAve,YAve);
    YVari = getVariance(data,"y",XAve,YAve);
    XYVari = getVariance(data,"xy",XAve,YAve);

    return XYVari / (Math.sqrt(XVari * YVari));
  }// end of getCorrelationCoefficient


  /*
   * 目的   : 分散や共分散を計算する
   * 引数   : data 数値データの配列への参照
   *       : axis "x" or "y" or "xy"
   *       : xave x の平均値
   *       : yave y の平均値
   * 戻り値 : 分散(または共分散)
   */
  static double getVariance(Vector data,String axis,
                            double xave,double yave){
    double xvari = 0;
    double yvari = 0;
    double xyvari = 0;

    double tempvalX = 0;
    double tempvalY = 0;
    double tempvalXY = 0;
    double x,y;
    Point temppos;
    for (int i=0; i<data.size(); ++i){
      temppos = (Point)data.get(i);
      x = (double) temppos.getX();
      tempvalX += Math.pow(x - xave,2);
      y = (double) temppos.getY();
      tempvalY += Math.pow(y - yave,2);
      tempvalXY += (x - xave) * (y - yave);
    }
    if (axis =="x") {
      return tempvalX / data.size();
    } else if (axis =="y") {
      return tempvalY / data.size();
    } else if (axis =="xy") {
      return tempvalXY / data.size();
    } else {
      return Double.NaN;
    }
  }// end of getVariance


  /*
   * 目的   : x 座標かy 座標かのデータの平均値を計算する
   * 引数   : data 数値データの配列への参照
   *       : axis "x" or "y"
   * 戻り値 : 平均値
   */
  static double getAverage(Vector data,String axis){
    double tempvalX = 0;
    double tempvalY = 0;
    double x,y;
    Point temppos;
    for (int i=0; i<data.size(); ++i){
      temppos = (Point)data.get(i);
      x = (double) temppos.getX();
      tempvalX += x;
      y = (double) temppos.getY();
      tempvalY += y;
    }
    if (axis =="x") {
      return tempvalX / data.size();
    } else if (axis =="y") {
      return tempvalY / data.size();
    } else {
      return Double.NaN;
    }
  }// end of getAverage


  /*
   * 目的 : 線形最小二乗法を実行
   * 引数 : data 数値データの配列への参照
   *       val[] 戻り値用 回帰直線の定数値
   */
  static void SenkeiSaishouJijyouHou(Vector Data,
                                     double val[]){
    double x,y,x_sum=0,y_sum=0, xx_sum=0,xy_sum=0;
    int n;
    Point temp;
    n = Data.size();
    for(int i=0;i<Data.size();i++){
      temp = (Point)Data.get(i);
      x=(double)temp.getX();
      y=(double)temp.getY();
      x_sum+=x;
      y_sum+=y;
      xx_sum+=x*x;
      xy_sum+=x*y;
    }// of for i
    val[0]= (double) (n*xy_sum-x_sum*y_sum)
           / (double) (n*xx_sum-x_sum*x_sum);
    val[1]= (double) (xx_sum*y_sum - xy_sum * x_sum)
           / (double) (n*xx_sum-x_sum*x_sum);
  }// end of SenkeiSaishouJijyouHou


  /*
   * 目的 : CSV の数値データを数値型に型変換
   * 引数 : v 文字列のデータを格納したVector
   *       Data 変換後のデータを格納したVector
   */
  static void KataHenkan(Vector v, Vector Data){
    for(int i=0; i <= (v.size()-1); i++){
      String str = (String)v.get(i);
      StringTokenizer st
                 = new StringTokenizer(str, ",");
      Point pos =
        new Point(
          Integer.parseInt((String)st.nextToken()),
          Integer.parseInt((String)st.nextToken())
        );
      Data.add(pos);
    }// of for i
  }// end of static void KataHenkan


  /*
   * 目的 : CSV ファイルから1 行ずつデータを読み込む
   * 引数 : filename データファイルのファイル名
   *       v データファイルから読み込んだデータ
   */
  static void readTextFromFile_AndSetVector
                      (String filename,Vector v) {
    try {
      FileReader fr = new FileReader(filename);
      BufferedReader br = new BufferedReader(fr);
      String rdata;
      String alldata = "";
      while((rdata = br.readLine()) != null) {
        v.add(rdata);
      }// of while
      fr.close();
    }catch(Exception e){
      System.out.println(e);
    }// of try catch
  }// of readTextFromFile_AndSetVector


  /*
   * 目的 : コンストラクタ
   * 引数 : Data プロットするデータ
   *       val 回帰直線の定数
   */
  public Sample_Correlation
                     (Vector Data,double val[]) {
    super();
    setTitle("最小二乗法のグラフをプロットする");
    setSize(SCREEN_MAX_X-SCREEN_MIN_X,
            SCREEN_MAX_Y-SCREEN_MIN_Y);
    setLayout(null);

    mc = new MyCanvas();
    mc.setBounds(SCREEN_MIN_X,SCREEN_MIN_Y, //左上隅の座標値
                 SCREEN_MAX_X,SCREEN_MAX_Y);//Width とHeight

    mc.setData(Data);
    mc.setVals(val);
    this.add(mc);
  }// end of Sample_RegLine(コンストラクタ)


  /**
   *目的 : 描画関係をまとめた。
   */
  class MyCanvas extends Canvas {

    Vector plotData;
    double val[];

    //ウインドウが再描画されるときにデータと
    //直線を再描画
    public void paint(Graphics g) {
      plotPoints();
      plotLine();
    }// end of paint

    //描画するデータへの参照を受け取る
    public void setData(Vector Data){
      plotData = Data;
    }//end of setData

    //描画する近似直線式の定数を受け取る
    public void setVals(double vals[]){
      val = vals;
    }// end of setVals

    //近似直線を描画する
    void plotLine(){
      Graphics g=getGraphics();
      g.setColor(Color.blue);
      g.drawLine(0,-(int)(val[1]) + SCREEN_MAX_Y,
        SCREEN_MAX_X,
        - (int)(SCREEN_MAX_X*val[0]+val[1]) + SCREEN_MAX_Y);
    }// end of plotLine

    //データを画面に点をプロットする。
    void plotPoints(){
      Graphics g=getGraphics();
      g.setColor(Color.red);

      Point temp = new Point();
      for (int i=0; i< plotData.size() ; i++){
        temp = (Point)plotData.get(i);
        g.drawOval((int)temp.getX(),
                   -(int)temp.getY() + SCREEN_MAX_Y,
                   DATA_PLOT_OVAL_WIDTH,
                   DATA_PLOT_OVAL_HEIGHT);
      }//end of for i
    }//end of plotPoints

  }// end of class MyCanvas


}// end of class Sample_Correlation

data001.csvを引数に渡した時の、プログラムの実行結果は次の通りです。

C:\>java Sample_Correlation ./data001.csv
value of a = 0.4660083326656518
value of b = 125.40942232192933
r = 0.9194321435674627
バッチジョブを終了しますか(Y/N)? y

前回確認した数値と比較すると、結構いい線を行っています。プログラムはちゃんと出来ているようです。良かった、良かった。では、負の相関を取るであろうdata002.csvについても処理してみましょう。定数aと相関係数rは負の値になったでしょうか?

コンピュータだけに頼らず、自分の頭脳でも確認している、という感覚を体験すると、とても自信がつくのが不思議です。ぜひともご自分の手と頭を使って取り組み、この自信ゲットしてください。

今回はここまで

今回は、相関係数の計算をJava言語で実装するところまで取り組みました。次回は相関係数がどのように組み立てられた値なのか、数学的に詳しく解説します。

おすすめ記事

記事・ニュース一覧