Hướng dẫn giải của Trung bình cộng


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Ở subtask trâu ~n \le 10~, các bạn đơn giản là duyệt nhị phân để tìm mọi cách phân chia các đoạn.

Đây là một bài tối ưu QHĐ đơn thuần!

Giải bằng QHĐ như sau:

Gọi ~dp(i,j)~ là xét tới vị trí ~i~ ở dãy ~1~ và ~j~ ở dãy hai thì số cách thỏa mãn để chia là bao nhiêu.

Dễ thấy ~dp(i,j) = \sum dp(x,y)~ thỏa mãn ~average(x+1,i) \le average(y+1,j)~.

Gọi ~Sa~ và ~Sb~ là prefix sum của hai dãy ~a~ và ~b~, lúc đó của ~a~ bằng:

  • ~average(x+1,i,a) = \frac{Sa(i) - Sa(x)}{i-x}~
  • ~average(y+1,j,b) = \frac{Sb(i) - Sb(x)}{j-y}~

Khi ta duyệt ~i~ và đang tính ~dp~ cho lớp ~dp(i,j)~ (hình dung nó như một cái bảng, ta đang duyệt tới hàng ~i~):

  • Ta sẽ sort các giá trị ~\frac{Sa(i) - Sa(x)}{i-x}~ của các ~0 \le x \le i-1~ lại
  • Sau đó khi tính ~dp(i,j)~, công việc của ta sẽ là đi duyệt các giá trị ~y~ để tính ~average(y+1,j,b)~. Lúc này, các giá trị ~x~ thỏa mãn ~average(x+1,i,a) \le average(y+1,j,b)~ sẽ nằm liên tiếp nhau ở phần prefix của từng cột, do nó đã được sort. Nên ta chỉ cần chặt nhị phân để tìm kiếm và dùng prefix sum để cập nhật hàm ~dp~.

Tới đây ta mới thu được thuật toán ~O(n^3 \times log)~. Để bỏ ~log~, ta nhận xét rằng nếu ta duyệt các giá trị ~y~ sao cho ~average(y+1,j,b)~ tăng dần, thì lúc cập nhật giá trị ~dp~, ta không cần tìm kiếm nhị phân nữa mà chỉ cần dùng hai con trỏ. Như vậy ta có thể chuẩn bị trước bằng cách sort ~y~ theo ~average(y+1,j,b)~ rồi bắt đầu duyệt và tính toán.

Độ phức tạp: ~O(n^3)~

Code (fryingduc):

#include "bits/stdc++.h"
using namespace std;

#ifdef duc_debug
#include "bits/debug.h"
#else
#define debug(...)
#endif

const int maxn = 505;
const int mod = 1e9 + 7;
int n, a[maxn];
int b[maxn];
long long pa[maxn], pb[maxn];
int f[maxn][maxn];
int g[maxn][maxn];
vector<int> ordd[maxn];

int add(int x, int y) {
  if(y < 0) y += mod;
  x = x + y;
  if(x >= mod) x -= mod;
  return x;
}
void solve() {
  cin >> n;
  for(int i = 1; i <= n; ++i) {
    cin >> a[i];
    pa[i] = pa[i - 1] + a[i];
  }
  for(int i = 1; i <= n; ++i) {
    cin >> b[i];
    pb[i] = pb[i - 1] + b[i];
  }
  for(int j = 1; j <= n; ++j) {
    ordd[j].resize(j);
    iota(ordd[j].begin(), ordd[j].end(), 0);
    sort(ordd[j].begin(), ordd[j].end(), [&](const int &x, const int &y) -> bool {
      return 1ll * (pb[j] - pb[x]) * (j - y) < 1ll * (pb[j] - pb[y]) * (j - x);
    });
  }
  f[0][0] = 1;
  for(int i = 1; i <= n; ++i) {
    vector<int> ord(i);
    iota(ord.begin(), ord.end(), 0);
    sort(ord.begin(), ord.end(), [&](const int &x, const int &y) -> bool {
      return 1ll * (pa[i] - pa[x]) * (i - y) < 1ll * (pa[i] - pa[y]) * (i - x);
    });
    for(int j = 0; j <= n; ++j) {
      for(int p = 0; p < i; ++p) {
        if(!p) g[j][0] = f[ord[p]][j];
        else g[j][p] = add(g[j][p - 1], f[ord[p]][j]);
      }
    }
    for(int j = 1; j <= n; ++j) {
      int ptr = -1;
      vector<int> &ord_j = ordd[j];
      for(int k = 0; k < j; ++k) {
        while(ptr + 1 < i and 1ll * (pb[j] - pb[ord_j[k]]) * (i - ord[ptr + 1]) >= (pa[i] - pa[ord[ptr + 1]]) * (j - ord_j[k])) {
          ++ptr;
        }
        if(ptr != -1) {
          f[i][j] = add(f[i][j], g[ord_j[k]][ptr]);
        }
      }
    }
  }
  cout << f[n][n];
}
signed main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  solve();

  return 0;
}