Python奮闘記

主にPythonのことを書くつもりだったけど、プログラミング周り全般の備忘録ということにした。大体競プロ。

ARC067-F Yakiniku Restaurants

問題

F - Yakiniku Restaurants

 N軒の店があり、それぞれ番号 1\cdots Mの肉があり、美味しさ Bが決まっている。 また、隣あう店を移動することができ、そのコスト Aが決まっている。 訪れ方と肉の選び方を工夫することで、番号 1\cdots Mの肉を集めたときの(美味しさ合計-コスト)の最大値を求める。

気持ち

区間 [l,r)の店を訪れるとするとコストが決まるので考えやすい。 区間を固定すると、各肉で独立に考えることができ、 それぞれ区間 [l,r) 内で一番美味しい肉をとってくれば良い。 ただこれだと計算量が O(N^2 M)となって間に合わない。

どうにか状態をまとめられないか考えてみる。 当たり前なのだが、肉番号 jを固定したとき、区間ごとに選ばれる肉は最大でも N種類しかない。 なので、各肉 i=1\cdots Nに対して「どの区間なら一番美味しいものとして選ばれるか」を考えると良さそう。

これはすなわち、「区間の他の要素が B[ i ] [ j ] より小さい」ということになる。もう少し区間 [l,r)として要求される特徴を整理すると、

  •  L_i B[i][j] > B[L_i][j]を満たす最小のindとしてとる。
  •  R_i B[i][j] > B[R_i-1][j]を満たす最大のindとしてとる。
  •  l=L_i .. iと各 r=i+1 ..R_i の組み合わせ [l,r)が肉 iを選ぶ区間になる。

となる。つまり各 iに対して L_i,R_iの情報さえあれば事足りる。これらはスタックを使って左右から走査することにより、全ての i=1\cdots Nに対して O(N)で求めることができる。

で、残る問題は各 iの情報をどうやって消すかということになる。 iの情報が残っていると、後で区間を固定した時に結局計算量が増えてしまうためである。 iの情報を消したものとして、それっぽい以下を考える。

 dp [ L ] [ R ] = 区間 [L, R) の部分区間、すなわち L\le l かつ r\le Rなる区間 [l, r) 全てに対して足す美味しさ

これに情報を貯めていければ、最後に区間固定して計算する際に区間を縮めながらやっていけばなんとかなりそうである。

ただ、 dp[L_i ] [ R_i ]  B[ i ] [ j ] をそのまま足すのはまずい。区間 [L_i, R_i) の部分区間には iが含まれないものもあるからである。具体的には r\le iまたは i\lt lであるような [l,r)である(図を描いた方がわかりやすそうだが、めんどい)。 そのような「余計に足されてしまう区間」は dp[L_i ] [i ]  dp[i+1 ] [ R_i ] に集約される。なのでこいつらに - B [ i ] [ j ] を足しておけば後で差し引き0になる。

以上で解けた。計算量は O(N^2 + NM)である。計算量は増えるが、 L_i,R_iを求めるところはBIT使うのが自然かもしれない。

ACコード

上とはDPの定義が違って、dp[長さ][左端]にしている。そっちのが累積しやすそうだったからだが、あまり手間は変わらなさそう。

Submission #20276546 - AtCoder Regular Contest 067

def solve(N, M, A, B):
    INF = 10**18
    # dp[length][left]
    dp = [[0]*(N+1-i) for i in range(N+1)]

    for m in range(M):
        # find segment
        Left = [-1]*N
        Right = [-1]*N
        stack = [(INF, 0)]
        for i in range(N):
            b = B[i][m]
            while stack[-1][0] < b:
                stack.pop()
            Left[i] = stack[-1][1]
            stack.append((b, i+1))
        stack = [(INF, N)]
        for i in reversed(range(N)):
            b = B[i][m]
            while stack[-1][0] < b:
                stack.pop()
            Right[i] = stack[-1][1]
            stack.append((b, i))
        
        for i, (l, r) in enumerate(zip(Left, Right)):
            b = B[i][m]
            dp[r-l][l] += b
            dp[i-l][l] -= b
            dp[r-i-1][i+1] -= b
    
    SA = [0]
    for a in A:
        SA.append(SA[-1]+a)
    
    ans = -INF
    for length in reversed(range(1, N+1)):
        for l in range(N+1-length):
            d = dp[length][l]
            ans = max(ans, d - (SA[l+length-1]-SA[l]))
            dp[length-1][l] += d
            dp[length-1][l+1] += d
            if length >= 2:
                dp[length-2][l+1] -= d

    return ans

感想

解説と似てたけどちょっと違ったので書いた。いもす法と言われてみるとそうなのかな。