【忘却のJava#5】組み合わせ最適化問題を解くプログラムをJavaで書く(2)

やること

以下の記事で紹介したCのプログラムをJavaに書き換えます。
wing-degital.hatenablog.com

Javaへの書き換えのポイント

せっかくなのでCにはないJavaならではの方法をわざわざ使って実装しました。

classと二次元配列

どんなデータ構造?

classはオブジェクトの設計書のようなものでCの構造体のようにメンバー(変数や配列)を定義できます。
構造体と違うのはメンバーに加えてメソッド(関数)が定義できるところです。

用途は?

グラフを表現するために使いました。
頂点数と辺の集合をメンバーで持たせています。辺の集合は二次元配列で定義してます。

ArrayList

どんなデータ構造?

可変長の配列です。
配列への要素の追加、削除といった操作ができるメソッドがついていて便利です。

用途は?

辺集合とUnion-Findの部分集合を表現するために使いました。
正直、この手の組み合わせ最適化問題は固定長の集合で計算できるのでclassと配列でも表現できるのですが、無理やり使いました。

HashSet

どんなデータ構造?

要素の重複を許さない集合です。
内部的には配列とハッシュ値と線形リストの組み合わせで実現してるっぽいですね。

用途は?

Union-Findの集合の管理に使おうと思ったのですが、データ構造的に冗長過ぎるのでやめました。
配列だけで表現できます。

Iterator / 拡張for文

ArraySetはIterator、二次元配列は拡張for文でループを回してみました。
慣れてないせいなのか、どちらも直感的に分かりにくい。
特に二次元配列は拡張for文は馴染みませんね。。

Javaのコード

実際に書き換えたのがこちらです。

Graph.java

public class Graph {
    int ord;
    int [][] adj;

    public Graph(int n){
        this.ord = n;
        this.adj = new int[n][n];
    }

    public void setGraph(int input[][]){
        adj = input;
    }

    public void printGraph(){
        for (int array[] : adj){
            for (int val : array) {
                System.out.printf("%4d", val);
            }
            System.out.println("");
        }
    }

    public int countTotalEdgeCost() {
        int sum = 0;
        for (int i = 0; i < this.ord; i++) {
            for (int j = i + 1; j < this.ord; j++) {
                sum += this.adj[i][j];
            }
        }
        return sum;
    }
}

EdgeList.java

import java.util.ArrayList;
import java.util.Iterator;

public class EdgeList {

    static ArrayList<Edge> edgeList = new ArrayList<Edge>();

    public EdgeList(Graph g){
        for (int i = 0; i < g.ord; i++) {
            for (int j = i + 1; j < g.ord; j++) {
                if(g.adj[i][j]> 0) {
                    this.addEdgeList(g.adj[i][j], i, j);
                }
            }
        }
    }

    public void addEdgeList(int cost, int src, int dest) {
        Edge ed = new Edge(cost, src, dest);
        edgeList.add(ed);
    }

    public void printEdgeList() {
        for (Iterator<Edge>itr = edgeList.iterator(); itr.hasNext();) {
            Edge e = itr.next();
            e.printEdge();
        }
    }

    public void swapEdge(int i, int j) {
        Edge temp = new Edge(0, 0, 0);
        temp = edgeList.get(i);
        edgeList.set(i, edgeList.get(j));
        edgeList.set(j, temp);
    }

    public void quickSortEdgeList(int low, int high) {
        int mid = (low + high) / 2;
        int i = low;
        int j = high;
        double pivot = edgeList.get(mid).cost;
        while (true) {
            while (edgeList.get(i).cost < pivot) { i++; }
            while (edgeList.get(j).cost > pivot) { j--; }
            if (i >= j) { break; }
            swapEdge(i++, j--);
        }
        if (low < i - 1 ) { quickSortEdgeList(low, i - 1); }
        if (j+1 < high ) { quickSortEdgeList(j + 1, high); }
    }
    class Edge {
        int cost;
        int src;
        int dest;

        public Edge(int cost, int src, int dest) {
            this.cost = cost;
            this.src = src;
            this.dest = dest;
        }

        public void printEdge() {
            System.out.printf("%d<->%d (cost:%d)%n", this.src, this.dest, this.cost);
        }

    }
}

mst.java

import java.util.ArrayList;
import java.util.Iterator;

class Kruskal {

    public Graph execKruskal(Graph g, Graph tree) {
        EdgeList edl = new EdgeList(g);
        FlagmentList flml = new FlagmentList(g.ord);
        int edge_num = edl.edgeList.size();
        int src = 0;
        int dest = 0;
        int cnt_edge = 0;
        edl.quickSortEdgeList(0, edge_num-1);
        for (int i = 0; i < edge_num; i++) {
            src = edl.edgeList.get(i).src;
            dest = edl.edgeList.get(i).dest;
            // find if cycle exist or not
            if (flml.findParent(src) != flml.findParent(dest)) {
                // add edge in MST
                tree.adj[src][dest] = g.adj[src][dest];
                tree.adj[dest][src] = g.adj[dest][src];
                // marge set
                flml.Union(flml.flmList.get(src).parent, flml.flmList.get(dest).parent);
                cnt_edge++;
            }
            if (cnt_edge == (g.ord -1)) { break; }
        }
        //edl.printEdgeList();
        return tree;
    }

    class FlagmentList {

        ArrayList<Flagment> flmList = new ArrayList<Flagment>();

        public FlagmentList(int ord) {
            for (int i = 0; i < ord; i++) {
                this.addFlagmentList(i, 0);
            }
        }

        public void addFlagmentList(int parent, int rank) {
            Flagment flm = new Flagment(parent, rank);
            flmList.add(flm);
        }
        int findParent(int x) {
            return flmList.get(x).parent;
        }

        void Union(int x, int y) {
            Flagment flmX = flmList.get(x);
            Flagment flmY = flmList.get(y);
            if (flmX.rank > flmY.rank){
                flmY.setFlagment(flmX.parent, flmX.rank);
            } else {
                flmX.setFlagment(flmY.parent, flmY.rank);
                if (flmX.rank == flmY.rank) {
                    flmX.addRank();
                    flmY.addRank();
                }
            }
        }

        class Flagment {
            int parent;
            int rank;

            public Flagment(int parent, int rank){
                this.parent = parent;
                this.rank = rank;
            }

            public void setFlagment(int parent, int rank){
                this.parent = parent;
                this.rank = rank;
            }
           public void setFlagment(int parent, int rank){
                this.parent = parent;
                this.rank = rank;
            }

            public void addRank(){
                this.rank++;
            }

        }

    }
}
public class mst {
    public static void main(String[] args) {

        // create graph
        int maze[][] = {
            { 0,  7, 15,  0,  0,  0 },
            { 7,  0, 12,  6,  0,  0 },
            {15, 12,  0, 18, 17,  2 },
            { 0,  6, 18,  0, 13,  0 },
            { 0,  0, 17, 13,  0, 24 },
            { 0,  0,  2,  0, 24,  0 }
        };
        int n = maze.length;
        Graph g = new Graph(n); // graph
        Graph tree = new Graph(n); // graph
        int tc; // total cost

        // make graph
        g.setGraph(maze);

        // show test data
        System.out.println("Tree (result)");
        g.printGraph();

        // make MST by Kruskal Method
        Kruskal kruskal= new Kruskal();
        tree = kruskal.execKruskal(g, tree);
        // show result
        System.out.println("Tree (result)");
        tree.printGraph();
        tc = tree.countTotalEdgeCost();
        System.out.printf("total cost:%d\n", tc);
    }
}

実行結果

前回と同じように最適解が得られました。

# java mst
Graph (problem)
   0   7  15   0   0   0
   7   0  12   6   0   0
  15  12   0  18  17   2
   0   6  18   0  13   0
   0   0  17  13   0  24
   0   0   2   0  24   0

Tree (result)
   0   7   0   0   0   0
   7   0  12   6   0   0
   0  12   0   0   0   2
   0   6   0   0  13   0
   0   0   0  13   0   0
   0   0   2   0   0   0
total cost:40

参考

Javaコードで書くにあたり、こちらの記事でお勉強させていただきました。

qiita.com

qiita.com

www.moriwaki.net

以上。