2014年6月7日土曜日

魔方陣(CodeIQ)

CodeIQで、Short Coder @ozy4dm Ozyさんの出題「魔方陣ヌルヌル」を解いた。1列の和が0となることを除いては普通の魔方陣とルールは同じ。
連立一次方程式に帰着して力技で解くことにした。理由は、魔方陣の解き方について全く知らなかったこと(効率的に空白を埋めていく方法があるらしい)と、連立一次方程式の解法について復習したかったから。

求め方は、まず各マスに番号を振る。たとえば3×3の魔方陣の場合はこんな感じ。

123
456
789

それぞれのマスに入る数字をx(1)…x(n^2)とすると、以下のようなn^2元連立一次方程式を立てることができる。

x(1)+x(2)+…x(n)=0
x(n+1)+x(n+2)+…+x(2n)=0

x(n(n-1)+1)+x(n^2-n+2)+…+x(n^2)=0

x(1)+x(n+1)+…+x(n(n-1)+1)=0
x(2)+x(n+2)+…+x(n(n-1)+2)=0

x(n)+x(2n)+…+x(n^2)=0

x(1)+x(n+2)+…+x(n^2)=0
x(n)+x(2n-1)+…+x(n(n-1)+1)=0

これを2n+2行n^2列の行列Aを用いてAx=0としてガウス・ジョルダン法で解く。ただし、rank(A)=n^2にはならない。ちゃんと確認していないがrank(A)=2n+1になるもよう。rank(A)≠n^2のとき解はn^2-rank(A)個の未知のパラメータを含むので、このパラメータをx(1)…x(n^2)適当に選ぶ。
解き方は以上だが、魔方陣の場合は各数字が整数であるという制約があるので、ガウス・ジョルダン法をそのまま適用するのではなく、手順の途中で小数が入り込まないように工夫する必要がある。また、パラメータも解が整数になるように選ぶ必要がある(本来はちゃんと探索するべきだろうけど今回は適当に決めたらうまくいったのでプログラム中では決め打ちになっている)。

以下ソースコード。使用言語はFortran。前進消去と交代代入を別のプログラムにした。

ソースコード1(ガウス・ジョルダン法の前進消去までを行う)
!****************************************************************************
!
    program msquare1
!
!****************************************************************************


    implicit none

    integer, parameter :: nstde = 0

    ! 変数
    integer :: m
    ! m : 魔方陣のサイズ(1辺に含まれる数字の個数)
    integer, allocatable :: a(:,:), b(:), x(:)
    ! a(2*m+2,m**2) : 連立一次方程式の係数
    ! b(2*m+2) : 連立一次方程式の右辺
    ! x(m**2) : 解(1次元表記)
    !
    ! 1次元配列xと魔方陣yの関係
    ! x(k) = y(i, j)
    ! k = (i-1)*m+j
    ! i = k/m+1
    ! j = mod(k,m)
   
    integer, allocatable :: ix(:)
    integer :: rank
    ! ix(m**2) : 列の入れ替え情報
    ! rank : 行列のランク
   
    character(len=100) :: ofile
    ! ofile : 出力ファイル名
   
    integer :: io
   
    write(nstde, *)'# m = ?'
    read(*, *)m
   
    allocate(a(2*m+2, m**2), b(2*m+2), x(2*m+2))
   
    ! --- a, bを生成する ---
    call struct_a(m, a)

!    b = (m**2)*(m**2+1)/2/m ! 普通の魔方陣の条件
    b = 0 ! ヌルヌルした魔方陣の条件(右辺=0)
   
    write(nstde, *)'# --- a ---'
    call print_mat(nstde, 2*m+2, m**2, a)
    write(nstde, *)'# --- b ---'
    call print_mat(nstde, 2*m+2, 1, b)
   
    write(nstde, *)'# processing...'
   
    ! aのランクを求める
    allocate(ix(m**2))
    call calc_rank(2*m+2, m**2, a, b, ix, rank)
       
    ! 結果を出力する
    write(ofile, 1000)m
    1000 format('matrix',i2.2,'.txt')
    open(11, file = trim(ofile), status = 'unknown', iostat = io)
    if(io /= 0)then
        write(nstde, *)'# can not open ' // trim(ofile) // '.'
        stop
    endif
    write(11, *)'# --- m ---'
    write(11, *) m
    write(11, *)'# --- rank ---'
    write(11, *) rank
    write(11, *)'# --- a ---'
    call print_mat(11, 2*m+2, m**2, a)
    write(11, *)'# --- b ---'
    call print_mat(11, 2*m+2, 1, b)
    write(11, *)'# --- ix ---'
    call print_mat(11, m**2, 1, ix)
    close(11)
   
    write(nstde, *)'# --- m ---'
    write(nstde, *) m
    write(nstde, *)'# --- rank ---'
    write(nstde, *) rank
    write(nstde, *)'# --- a ---'
    call print_mat(nstde, 2*m+2, m**2, a)
    write(nstde, *)'# --- b ---'
    call print_mat(nstde, 2*m+2, 1, b)
    write(nstde, *)'# --- ix ---'
    call print_mat(nstde, m**2, 1, ix)
   
    end program msquare1

!****************************************************************************
    subroutine struct_a(m, a)
!****************************************************************************
    implicit none
    integer, intent(in) :: m
    integer, intent(inout) :: a(2*m+2, m**2)
   
    integer :: i, j, k
   
    ! --- 初期化 ---
    a=0
    ! --- 横 ---
    do i = 1, m
        do j = 1, m
            k = (i-1)*m+j
            a(i, k) = 1
        enddo
    enddo
   
    ! 縦
    do j = 1, m
        do i = 1, m
            k = (i-1)*m+j
            a(j+m, k) = 1
        enddo
    enddo
   
    ! 斜め
    do i = 1, m
        k = (i-1)*m+i
        a(2*m+1, k) = 1
    enddo
    do i = 1, m
        k = (i-1)*m+(m-i+1)
        a(2*m+2, k) = 1
    enddo
    end subroutine struct_a
   
!****************************************************************************
    subroutine calc_rank(m, n, a, b, ix, rank)
!****************************************************************************
    implicit none
    integer, parameter :: nstde = 0
    integer, intent(in) :: m, n
    integer, intent(inout) :: a(m, n), b(m), ix(n), rank
   
    integer :: i, j, ii, jj, tmp
   
    do i = 1, n
        ix(i) = i
    enddo

l1: do i = 1, m
        do jj = i, n
        do ii = i, m
            if(a(ii, jj) /= 0)then
                call chgrow(m, n, a, b, i, ii)
                call chgcol(m, n, a, ix, i, jj)
                call setcol0(m, n, a, b, i)
                call normalize2(m, n, a, b)
                cycle l1
            endif
        enddo
        enddo
        rank =  i-1
        return
    enddo l1

    end subroutine calc_rank
   
!****************************************************************************
    subroutine chgrow(m, n, a, b, i1, i2)
!****************************************************************************
    implicit none
    integer, intent(in) :: m, n
    integer, intent(inout) :: a(m, n), b(m)
    integer, intent(in) :: i1, i2
   
    integer, allocatable :: rtmp(:)
    integer :: tmp
   
    allocate(rtmp(n))
    rtmp = a(i1, :)
    a(i1, :) = a(i2, :)
    a(i2, :) = rtmp
    deallocate(rtmp)
    tmp = b(i1)
    b(i1) = b(i2)
    b(i2) = tmp  
    end subroutine chgrow

!****************************************************************************
    subroutine chgcol(m, n, a, ix, j1, j2)
!****************************************************************************
    implicit none
    integer, intent(in) :: m, n
    integer, intent(inout) :: a(m, n), ix(n)
    integer, intent(in) :: j1, j2
   
    integer, allocatable :: rtmp(:)
    integer :: tmp
   
    allocate(rtmp(m))
    rtmp = a(:, j1)
    a(:, j1) = a(:, j2)
    a(:, j2) = rtmp
    deallocate(rtmp)
    tmp = ix(j1)
    ix(j1) = ix(j2)
    ix(j2) = tmp
    end subroutine chgcol

!****************************************************************************
    subroutine setcol0(m, n, a, b, i)
!****************************************************************************
    implicit none
    integer, intent(in) :: m, n
    integer, intent(inout) :: a(m, n), b(m)
    integer, intent(in) :: i
   
    integer :: ii, j, tmp
   
    do ii = 1, m
        if(ii == i)cycle
        tmp = a(ii, i)
        do j = 1, n
            a(ii, j) = a(i, i)*a(ii,j) - tmp*a(i, j)
        enddo
        b(ii) = a(i, i)*b(ii) - tmp*b(i)
    enddo
    end subroutine setcol0

   
!****************************************************************************
    subroutine normalize(m, n, a, x)
!****************************************************************************
    implicit none
    integer, intent(in) :: m, n
    integer, intent(inout) :: a(m, n), x(m)

    integer :: i, j
    integer :: tmp, tmp2
   
 l2:do i = 1, min(m, n)
        if(a(i, i) /= 0)then
         l1:do tmp2 = a(i, i), 1, -a(i, i)/abs(a(i, i))
                do j = 1, n
                    if(mod(a(i, j), tmp2) /=0)cycle l1
                enddo
                if(mod(x(i), tmp2) /= 0)cycle l1
                do j = 1, n
                    a(i, j) = a(i, j) / tmp2
                enddo
                x(i) = x(i) / tmp2
                cycle l2
            enddo l1
        endif
    enddo l2
    end subroutine normalize

!****************************************************************************
    subroutine normalize2(m, n, a, x)
!****************************************************************************
    implicit none
    integer, intent(in) :: m, n
    integer, intent(inout) :: a(m, n), x(m)

    integer :: i, j
    integer :: tmp, tmp2
   
 l2:do i = 1, m
        tmp = maxval(abs(a(i,:)))
        if(tmp /= 0)then
         l1:do tmp2 = tmp, 1, -tmp/abs(tmp)
                do j = 1, n
                    if(mod(a(i, j), tmp2) /=0)cycle l1
                enddo
                if(mod(x(i), tmp2) /= 0)cycle l1
                do j = 1, n
                    a(i, j) = a(i, j) / tmp2
                enddo
                x(i) = x(i) / tmp2
                cycle l2
            enddo l1
        endif
    enddo l2
    end subroutine normalize2
   
!****************************************************************************
    subroutine print_mat(unit, m, n, a)
!****************************************************************************
    implicit none
    integer, parameter ::  nstde = 0
    integer, intent(in) :: unit, m, n, a(m, n)
   
    integer :: i, j

    do i = 1, m
        write(unit, 1000)(a(i, j), j=1, n)
        1000 format(100(1x, i4))
    enddo
    end subroutine print_mat

ソースコード2(ガウス・ジョルダン法の交代代入以降を行う)
!****************************************************************************
!
    program msquare2
!
!****************************************************************************
    implicit none
    integer, parameter :: nstde = 0

    ! 変数
    integer :: m
    ! m : 魔方陣のサイズ(1辺に含まれる数字の個数)
    integer, allocatable :: a(:,:), b(:), x(:), y(:,:)
    ! a(2*m+2,m**2) : 連立一次方程式の係数
    ! b(2*m+2) : 連立一次方程式の右辺
    ! x(m**2) : 解(1次元表示)
    ! y(m, m) : 解(2次元表示)
    !
    ! 1次元配列xと魔方陣yの関係
    ! x(k) = y(i, j)
    ! k = (i-1)*m+j
    ! i = k/m+1
    ! j = mod(k,m)
   
    integer, allocatable :: ix(:)
    integer :: rank
    ! ix(m**2) : 列の入れ替え情報
    ! rank : 行列のランク
   
    character(len=100) :: infile, ofile

    ! infile : msquare1の出力ファイル

    integer :: i, j, io

    write(nstde, *)'# infile, ofile=?'
    read(*, *)infile, ofile
   
    open(11, file = trim(infile), status = 'unknown', iostat = io)
    if(io /= 0)then
        write(nstde, *)'# can not open ' // trim(infile) //'.'
        stop
    endif
    read(11, *) ! skip header
    read(11, *)m
    write(nstde, *)'# m=', m
   
    allocate(a(2*m+2, m**2), b(2*m+2), x(m**2), ix(m**2))
   
    read(11, *) ! skip header
    read(11, *)rank
    write(nstde, *)'# rank=', rank
   
    read(11, *) ! skip header
    do i = 1, 2*m+2
        read(11, *)(a(i, j), j = 1, m**2)
    enddo
    write(nstde, *)'# --- a ---'
    call print_mat(nstde, 2*m+2, m**2, a)  

    read(11, *) ! skip header
    do i = 1, 2*m+2
        read(11, *)b(i)
    enddo
    write(nstde, *)'# --- x ---'
    call print_mat(nstde, 2*m+2, 1, b)
       
    read(11, *) ! skip header
    do i = 1, m**2
        read(11, *)ix(i)
    enddo

    write(nstde, *)'# --- x ---'
    call print_mat(nstde, m**2, 1, ix)

    close(11)

    call calc_x(2*m+2, m**2, rank, a, b, ix, x)
   
    call check_x(m**2, x)

    allocate(y(m, m))
    do i = 1, m
    do j = 1, m
        y(i, j) = x((i-1)*m+j)
    enddo
    enddo
   
    call check_y(m, y)
   
    write(nstde, *)'# --- answer ---'
    call print_mat(nstde, m, m, y)
    open(12, file = trim(ofile), status = 'unknown', iostat = io)
    if(io /= 0)then
        write(nstde, *)'# can not open ' // trim(ofile) // '.'
        stop
    endif
    call print_mat(12, m, m, y)
    close(12)  
    end program msquare2

!****************************************************************************
    subroutine calc_x(m, n, rank, a, b, ix, x)
!****************************************************************************
    implicit none
    integer, parameter :: nstde = 0
   
    integer, intent(in) :: m, n, rank, a(m, n), b(m), ix(n)
    integer, intent(out) :: x(n)
   
    integer, allocatable :: x_tmp(:)
    integer :: i, j, ii
   
    ! --- x_tmp(rank+1)~x_tmp(n)の生成(解が整数になるように適当に与える) ---
    allocate(x_tmp(n))
    x_tmp=0
    ii = 0
    do i = rank+1, n
        ii = ii + 2
        x_tmp(i) = ii
    enddo
   
    ! --- x_tmp(1)~x_tmp(rank)の生成 ---
    do i = 1, rank
        x_tmp(i) = b(i)
        do j = rank + 1, n
            x_tmp(i) = x_tmp(i) - a(i, j) * x_tmp(j)
        enddo
        if(mod(x_tmp(i), a(i, i)) /= 0)then
            write(nstde, *)'# not an integer.'
            stop
        endif
        x_tmp(i) = x_tmp(i) / a(i, i)
    enddo

    ! --- rank計算時に入れ替えていた変数を元に戻す ---  
    do i = 1, n
        x(ix(i))=x_tmp(i)
    enddo
    end subroutine calc_x
!****************************************************************************
    subroutine check_x(m, x)
!****************************************************************************
    implicit none
    integer, parameter :: nstde = 0
    integer, intent(in) :: m, x(m)
   
    integer :: i, j
    do i = 1, m-1
    do j = i+1, m
        if(x(i) == x(j))then
            write(nstde, *)i, j, x(i)
            1000 format(' x(',i2,')=x(',i2,')=',i4)
        endif
    enddo
    enddo
    end subroutine check_x  
!****************************************************************************
    subroutine check_y(m, y)
!****************************************************************************
    implicit none
    integer, parameter :: nstde = 0
    integer, intent(in) :: m, y(m, m)
    integer :: i, j, tmp
    write(nstde,*)'# --- sum ---'
    do i = 1, m
        write(nstde,*)sum(y(i, :))
    enddo
    do j = 1, m
        write(nstde,*)sum(y(:, j))
    enddo
    tmp = 0
    do i = 1, m
        tmp = tmp+y(i, i)
    enddo
    write(nstde,*)tmp
    tmp = 0
    do i = 1, m
        tmp = tmp+y(i, m-i+1)
    enddo
    write(nstde,*)tmp
    end subroutine check_y

!****************************************************************************
    subroutine print_mat(unit, m, n, a)
!****************************************************************************
    implicit none
    integer, parameter ::  nstde = 0
    integer, intent(in) :: unit, m, n, a(m, n)
   
    integer :: i, j

    do i = 1, m
        write(unit, 1000)(a(i, j), j=1, n)
        1000 format(100(1x, i4))
    enddo
    end subroutine print_mat



0 件のコメント:

コメントを投稿