Editorial for Trung bình cộng


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Ở subtask trâu ~n \le 10~, các bạn đơn giản là duyệt nhị phân để tìm mọi cách phân chia các đoạn.

Đây là một bài tối ưu QHĐ đơn thuần!

Giải bằng QHĐ như sau:

Gọi ~dp(i,j)~ là xét tới vị trí ~i~ ở dãy ~1~ và ~j~ ở dãy hai thì số cách thỏa mãn để chia là bao nhiêu.

Dễ thấy ~dp(i,j) = \sum dp(x,y)~ thỏa mãn ~average(x+1,i) \le average(y+1,j)~.

Gọi ~Sa~ và ~Sb~ là prefix sum của hai dãy ~a~ và ~b~, lúc đó của ~a~ bằng:

  • ~average(x+1,i,a) = \frac{Sa(i) - Sa(x)}{i-x}~
  • ~average(y+1,j,b) = \frac{Sb(i) - Sb(x)}{j-y}~

Khi ta duyệt ~i~ và đang tính ~dp~ cho lớp ~dp(i,j)~ (hình dung nó như một cái bảng, ta đang duyệt tới hàng ~i~):

  • Ta sẽ sort các giá trị ~\frac{Sa(i) - Sa(x)}{i-x}~ của các ~0 \le x \le i-1~ lại
  • Sau đó khi tính ~dp(i,j)~, công việc của ta sẽ là đi duyệt các giá trị ~y~ để tính ~average(y+1,j,b)~. Lúc này, các giá trị ~x~ thỏa mãn ~average(x+1,i,a) \le average(y+1,j,b)~ sẽ nằm liên tiếp nhau ở phần prefix của từng cột, do nó đã được sort. Nên ta chỉ cần chặt nhị phân để tìm kiếm và dùng prefix sum để cập nhật hàm ~dp~.

Tới đây ta mới thu được thuật toán ~O(n^3 \times log)~. Để bỏ ~log~, ta nhận xét rằng nếu ta duyệt các giá trị ~y~ sao cho ~average(y+1,j,b)~ tăng dần, thì lúc cập nhật giá trị ~dp~, ta không cần tìm kiếm nhị phân nữa mà chỉ cần dùng hai con trỏ. Như vậy ta có thể chuẩn bị trước bằng cách sort ~y~ theo ~average(y+1,j,b)~ rồi bắt đầu duyệt và tính toán.

Độ phức tạp: ~O(n^3)~

Code (fryingduc):

#include "bits/stdc++.h"
using namespace std;

#ifdef duc_debug
#include "bits/debug.h"
#else
#define debug(...)
#endif

const int maxn = 505;
const int mod = 1e9 + 7;
int n, a[maxn];
int b[maxn];
long long pa[maxn], pb[maxn];
int f[maxn][maxn];
int g[maxn][maxn];
vector<int> ordd[maxn];

int add(int x, int y) {
  if(y < 0) y += mod;
  x = x + y;
  if(x >= mod) x -= mod;
  return x;
}
void solve() {
  cin >> n;
  for(int i = 1; i <= n; ++i) {
    cin >> a[i];
    pa[i] = pa[i - 1] + a[i];
  }
  for(int i = 1; i <= n; ++i) {
    cin >> b[i];
    pb[i] = pb[i - 1] + b[i];
  }
  for(int j = 1; j <= n; ++j) {
    ordd[j].resize(j);
    iota(ordd[j].begin(), ordd[j].end(), 0);
    sort(ordd[j].begin(), ordd[j].end(), [&](const int &x, const int &y) -> bool {
      return 1ll * (pb[j] - pb[x]) * (j - y) < 1ll * (pb[j] - pb[y]) * (j - x);
    });
  }
  f[0][0] = 1;
  for(int i = 1; i <= n; ++i) {
    vector<int> ord(i);
    iota(ord.begin(), ord.end(), 0);
    sort(ord.begin(), ord.end(), [&](const int &x, const int &y) -> bool {
      return 1ll * (pa[i] - pa[x]) * (i - y) < 1ll * (pa[i] - pa[y]) * (i - x);
    });
    for(int j = 0; j <= n; ++j) {
      for(int p = 0; p < i; ++p) {
        if(!p) g[j][0] = f[ord[p]][j];
        else g[j][p] = add(g[j][p - 1], f[ord[p]][j]);
      }
    }
    for(int j = 1; j <= n; ++j) {
      int ptr = -1;
      vector<int> &ord_j = ordd[j];
      for(int k = 0; k < j; ++k) {
        while(ptr + 1 < i and 1ll * (pb[j] - pb[ord_j[k]]) * (i - ord[ptr + 1]) >= (pa[i] - pa[ord[ptr + 1]]) * (j - ord_j[k])) {
          ++ptr;
        }
        if(ptr != -1) {
          f[i][j] = add(f[i][j], g[ord_j[k]][ptr]);
        }
      }
    }
  }
  cout << f[n][n];
}
signed main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  solve();

  return 0;
}