Hướng dẫn giải của Bệnh viện


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Tác giả: mrtee

Để thuận tiện, ta sẽ coi đây như một cái cây, với ~n~ đỉnh và ~m~ đỉnh đặc biệt (gọi là tập ~P~). Bài toán cần giải là đếm số đỉnh ~u~ sao cho ~max(dist(u,v)) \le k~, ~\forall v \in P~. Với ~dist(u,v)~ là đường đi ngắn nhất từ ~u~ tới ~v~.

Subtask 1,2

  • Dễ thấy với ~n \le 10^4~, ta hoàn toàn có thể ~dfs~ ~n~ lần.
  • Với mỗi đỉnh ~u~, . Gọi ~lv[v]~ là độ dài đường đi ngắn nhất từ ~u~ tới ~v~, ta có thể tính mảng này bằng cách ~dfs~ từ đỉnh ~u~.
  • Ta sẽ kiểm tra xem ~max(lv[v]) \forall v \in P~ có ~ \le k~ hay không, nếu có thì đây là một đỉnh thỏa mãn.
  • Độ phức tạp ~O(n^2)~.

Subtask 3

  • Cây lúc này sẽ suy biến về đường thẳng.
  • Ta có ví dụ sau:

    Imgur

  • Các đỉnh ~2,3, 6~ là các đỉnh đặc biệt. Tuy vậy, nếu coi ~1~ là gốc của cây thì ta thấy rằng, chỉ cần quan tâm tới đỉnh gần ~1~ nhất và đỉnh xa ~1~ nhất là hai đỉnh ~2,3~. Khi ấy khoảng cách xa nhất của một đỉnh tới một đỉnh đặc biệt sẽ là ~max~ khoảng cách từ đỉnh ấy tới đỉnh ~2~ và ~3~.
  • Vậy, để tổng quát, ta làm như sau:
    • Để cây sẽ được đi theo thứ tự đường thẳng, chọn gốc của cây là đỉnh có duy nhất một cạnh kề. Gọi ~x~ là đỉnh gần gốc nhất, ~y~ là đỉnh xa gốc nhất.
    • Gọi ~l[u]~ là ~dist(x,u)~, ~r[u]~ là ~dist(y,u)~. Như vậy chỉ cần kiểm tra xem ~max(l[u],r[u])~ có ~\le k~ hay không.
  • Độ phức tạp ~O(n)~.

Subtask 4

  • Có thể thấy ~subtask~ ~3~ là một gợi ý khá rõ cho ~subtask~ này.
  • Để giải được bài toán này, ta cần hiểu sâu một chút về diameter hay đường đi dài nhất trên cây. Hai đỉnh ~u,v~ là hai đỉnh đầu và cuối của diamter khi và chỉ khi với mọi đỉnh ~x~, ~max(dist(x,u),dist(x,v)) = ~ độ dài đường đi dài nhất của ~x~ tới một đỉnh trong tập.
  • Như vậy, ta cũng sẽ đi tìm diamter, nhưng ở một điều kiện mới, đó là tìm đường đi dài nhất trên cây giữa hai đỉnh đặc biệt, tức ~u,v \in P~.
  • Cách tìm diamter đã khá phổ biến. Đầu tiên ta ~dfs~ từ một đỉnh ~x~ bất kì, thu được đỉnh ~u~ là đỉnh xa với đỉnh ~x~ nhất. Sau đó lại ~dfs~ từ ~u~, thu được đỉnh ~v~ xa với đỉnh ~u~ nhất. Lúc này ~u,v~ là hai đỉnh đầu và cuối của diamter. Tuy nhiên, để tìm đường đi dài nhất giữa hai đỉnh đặc biệt, ta cần xét các ~x,u,v \in P~.
  • Lúc này ta xây dựng hai mảng ~l,r~ và giải tương tự như ~subtask~ ~3~.
  • Độ phức tạp ~O(N)~.

Code C++

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5+5;
vector<int> a[N];
int n,m,d;
int p[N];
int st;
int dist1[N],dist2[N],lv[N];
void dfs (int u, int p, int *lv) {
    for (int i=0; i<a[u].size(); i++) {
        int v = a[u][i];
        if (v == p) continue;
        lv[v] = lv[u] + 1;
        dfs (v,u,lv);
    }
}
signed main () {
    freopen("BV.INP", "r", stdin);
    freopen("BV.OUT", "w", stdout);
    cin >> n >> m >> d;
    for (int i=1; i<n; i++) {
        int u,v; cin >> u >> v;
        a[u].push_back(v);
        a[v].push_back(u);
    }
    for (int i=1; i<=m;i++) {
        int x; cin >> x;
        p[x] = true;
        st = x;
    }   
    dfs (1,1,lv);
    st = 0;
    for (int i=1; i<=n; i++) {
        if ((st == 0 || lv[st] < lv[i]) && p[i]) st = i;
    }
    dfs (st,st,dist1);
    int en = 0;
    for (int i=1; i<=n; i++) {
        if ((!en || dist1[en] < dist1[i])&& p[i]) en = i;
    }
    dfs (en,en,dist2);
    int ans = 0;
    for (int i=1; i<=n; i++) {
        if (dist1[i] <= d && dist2[i] <= d ) ans ++;
    }
    cout << ans;
}

Code Python

import sys
from collections import deque

sys.stdin = open("BV.INP", "r")
sys.stdout = open("BV.OUT", "w")

MAXN = 100001

edge = list([] for i in range(MAXN))
s = list(0 for i in range(MAXN))
depth1 = list(0 for i in range(MAXN))
depth2 = list(0 for i in range(MAXN))
depth3 = list(0 for i in range(MAXN))

n, m, d = 0, 0, 0

def bfs(root, depth):
    queue = deque()
    queue.append(root)
    vis = list(0 for i in range(MAXN))

    while len(queue):
        u = queue.pop()
        for v in edge[u]:
            if not vis[v]:
                depth[v] = depth[u] + 1
                vis[v] = 1
                queue.append(v)


def main():
    n, m, d = map(int, input().split())

    for i in range(n - 1):
        u, v = map(int, input().split())
        edge[u].append(v)
        edge[v].append(u)

    for p in list(map(int, input().split())): s[p] = 1

    bfs(1, depth1)
    bfs(max(list(x for x in range(1, n + 1) if s[x]), key = lambda x : depth1[x]), depth2)
    #print(depth2)
    bfs(max(list(x for x in range(1, n + 1) if s[x]), key = lambda x : depth2[x]), depth3)

    print(len(list(i for i in range(1, n + 1) if depth2[i] <= d and depth3[i] <= d)))

if __name__ == '__main__':
    main()