STNODE - VOI09 Nút st - xung yếu

Tác giả: skyvn97

Ngôn ngữ: C++

#ifndef graph_h
#define graph_h

// Public entry point of the solution (defined further down in this file).
//   N    - number of vertices; callers use 0-based ids, the implementation
//          shifts every id by +1 internally.
//   M    - number of edges.
//   S, T - 0-based source and target vertices.
//   U, V - arrays of length M; edge i is stored as directed from U[i] to V[i].
// Returns the number of vertices strictly between S and T on the BFS path
// that are not bypassed by any detour through off-path vertices; returns 0
// when S == T, when an edge S -> T exists, or when T is unreachable from S.
int calculate(int N, int M, int S, int T, int U[], int V[]);

#endif

#include<bits/stdc++.h>
// Capacity of every vertex-indexed / component-indexed global array.
#define MAX   100100
// Inclusive integer loop: i runs over a, a+1, ..., b (b is evaluated once).
#define FOR(i,a,b) for (int i=(a),_b=(b);i<=_b;i=i+1)
// Half-open integer loop: i runs over 0, 1, ..., n-1 (n is evaluated once).
#define REP(i,n) for (int i=0,_n=(n);i<_n;i=i+1)
// Iterator loop over container v; relies on the GNU __typeof extension.
#define FORE(i,v) for (__typeof((v).begin()) i=(v).begin();i!=(v).end();i++)
// Expands to the begin/end iterator pair of v, for passing to algorithms.
#define ALL(v) (v).begin(),(v).end()
using namespace std;
// Range-marking structure over the positions 1..n.
//
// update(l, r) marks every position in [l, r]; get(x) tells whether position
// x has been marked by any earlier update. Marks are stored on the tree nodes
// that exactly tile the updated range and are never pushed down, so a point
// query simply looks for a marked node on the root-to-leaf path.
//
// An empty tree (default-constructed, or built with a non-positive size) has
// no positions: update() is a no-op and get() returns false for every x.
class SegmentTree {
private:
    int n;                   // number of positions; valid indices are 1..n
    std::vector<bool> lazy;  // lazy[i]: the whole interval of node i is marked

    // Marks [u, v] inside node i, which covers [l, r].
    void update(int i,int l,int r,int u,int v) {
        if (l>v || r<u || l>r || v<u) return;   // disjoint or empty range
        if (u<=l && r<=v) {
            lazy[i]=true;                       // node fully covered: stop here
            return;
        }
        int m=(l+r)>>1;
        update(2*i,l,m,u,v);
        update(2*i+1,m+1,r,u,v);
    }
public:
    // Builds an empty tree.
    SegmentTree() : n(0) {}

    // Builds a tree over positions 1..n with nothing marked.
    // A non-positive n yields an empty tree instead of a wrapped-around
    // allocation size.
    SegmentTree(int n) : n(n>0 ? n : 0) {
        lazy.assign(4*static_cast<std::size_t>(this->n)+7,false);
    }

    // Marks every position in [l, r]; parts outside 1..n and empty ranges
    // (l > r) are ignored.
    void update(int l,int r) {
        if (n<=0) return;
        update(1,1,n,l,r);
    }

    // Returns true iff position x has been marked. Positions outside 1..n
    // (which includes every position of an empty tree) are never marked.
    bool get(int x) const {
        if (x<1 || x>n) return (false);   // also keeps lazy[] accesses in range
        int i=1;
        int l=1;
        int r=n;
        while (true) {
            if (lazy[i]) return (true);
            if (l==r) return (false);
            int m=(l+r)>>1;
            if (x>m) {
                i=2*i+1;
                l=m+1;
            } else {
                i=2*i;
                r=m;
            }
        }
    }
};
// adj[u] lists the heads v of the directed edges u -> v (1-based vertex ids).
vector<int> adj[MAX];
// Vertices of the BFS path, ordered from s to t once findPath() has run.
vector<int> path;
// BFS predecessor of each vertex: -1 = not reached, 0 = the source s itself.
int trace[MAX];
// Index of each vertex inside `path`, or -1 when the vertex is not on it.
int pathPos[MAX];
// n, m: vertex and edge counts; s, t: 1-based source and target;
// nComp: number of components produced by dfs(); cnt: DFS visit counter.
int n,m,s,t,nComp,cnt;
// num[v]: DFS visit number (0 = unvisited); low[v]: Tarjan low-link value;
// compID[v]: component id of an off-path vertex (0 = not assigned yet).
int low[MAX],num[MAX],compID[MAX];
// comp[c] holds the vertices assigned to component c (ids start at 1).
vector<int> comp[MAX];
// Stack of vertices visited by dfs() and not yet assigned to a component.
stack<int> st;
// Memo table for maxNode(), indexed by component id; filled with a sentinel
// below -1 by tarjan() to mean "not computed yet".
int storedMaxNode[MAX];
// Breadth-first search from s over adj.
// Fills trace[] with each reached vertex's predecessor (-1 = unreached; the
// source itself gets the marker 0, which no real 1-based vertex id uses).
// Returns true when t was reached through some predecessor.
bool bfs(void) {
    std::memset(trace, -1, sizeof trace);
    std::queue<int> frontier;
    trace[s] = 0;
    frontier.push(s);
    while (!frontier.empty()) {
        const int cur = frontier.front();
        frontier.pop();
        const std::vector<int> &out = adj[cur];
        for (std::size_t k = 0; k < out.size(); ++k) {
            const int nxt = out[k];
            if (trace[nxt] >= 0) continue;   // already discovered
            trace[nxt] = cur;
            frontier.push(nxt);
        }
    }
    return trace[t] > 0;
}
// Rebuilds the s -> t path from the BFS predecessors in trace[].
// Afterwards `path` lists the vertices from s to t and pathPos[v] is the
// index of v in that list (-1 for vertices that are not on the path).
void findPath(void) {
    std::memset(pathPos, -1, sizeof pathPos);
    // Walk the predecessor chain backwards from t, then flip it.
    int cur = t;
    while (cur != s) {
        path.push_back(cur);
        cur = trace[cur];
    }
    path.push_back(s);
    std::reverse(path.begin(), path.end());
    const int len = (int)path.size();
    for (int idx = 0; idx < len; ++idx) {
        pathPos[path[idx]] = idx;
    }
}
void dfs(int u) {
    low[u]=num[u]=++cnt;
    st.push(u);
    FORE(it,adj[u]) if (pathPos[*it]<0 && compID[*it]==0) {
        int v=*it;
        if (num[v]==0) {
            dfs(v);
            low[u]=min(low[u],low[v]);
        } else low[u]=min(low[u],num[v]);
    }
    if (low[u]==num[u]) {
        nComp++;
        int v;
        do {
            v=st.top();st.pop();
            compID[v]=nComp;
            comp[nComp].push_back(v);
        } while (v!=u);
    }
}
void tarjan(void) {
    FOR(i,1,n) if (pathPos[i]<0 && num[i]==0) dfs(i);
    memset(storedMaxNode,-0x3f,sizeof storedMaxNode);
}
int maxNode(int u) {
    if (storedMaxNode[u]>=-1) return (storedMaxNode[u]);
    int &res=storedMaxNode[u];
    res=-1;
    FORE(it,comp[u]) FORE(jt,adj[*it]) {
        if (pathPos[*jt]>=0) res=max(res,pathPos[*jt]);
        else if (compID[*jt]!=u) res=max(res,maxNode(compID[*jt]));
    }
    return (res);
}
int maxPathNode(int u) {
    int res=-1;
    FORE(it,adj[u]) if (pathPos[*it]<0) res=max(res,maxNode(compID[*it]));
    return (res);
}
int process(void) {
    SegmentTree myit((int)path.size()-2);
    REP(i,path.size()) {
        int j=maxPathNode(path[i]);
        if (i+1<=j-1) myit.update(i+1,j-1);
        //printf("From %d to %d\n",i,j);
    }
    int res=0;
    FOR(i,1,(int)path.size()-2) if (!myit.get(i)) res++;
    return (res);
}
int calculate(int N, int M, int S, int T, int U[], int V[]) {
    n=N;m=M;s=S+1;t=T+1;
    if (s==t) return (0);
    REP(i,m) {
        int u=U[i]+1;
        int v=V[i]+1;
        adj[u].push_back(v);
        if (u==s && v==t) return (0);
    }
	if (!bfs()) return (0);
	findPath();
	//REP(i,path.size()) printf("%d ",path[i]); printf("\n");
	tarjan();
	/*FOR(i,1,nComp) {
	    printf("Comp #%d:",i);
	    FORE(it,comp[i]) printf(" %d",*it);
	    printf("\n");
	}*/
	return (process());
}

#include <stdio.h>

// Reads "N M S T" followed by M edges "u v" (all 1-based) from standard
// input, converts them to the 0-based form calculate() expects, and prints
// the result on its own line. Exits with status 1 on malformed input.
int main() {
	//freopen("graph.in","r",stdin);
	//freopen("graph.out","w",stdout);

	int N,M,S,T;
	// Refuse to continue with uninitialised values if the header is incomplete.
	if (scanf("%d%d%d%d",&N,&M,&S,&T)!=4 || M<0) return 1;
	S--;T--;

	// Vectors own the edge buffers, so they are released on every exit path.
	std::vector<int> U(M);
	std::vector<int> V(M);

	for (int i = 0; i < M; i++) {
		if (scanf("%d%d",&U[i],&V[i])!=2) return 1;
		U[i]--;V[i]--;
	}

	// &U[0] is only valid for a non-empty vector; with M == 0 the arrays are
	// never read, so a null pointer is passed instead.
	int *uPtr = M>0 ? &U[0] : 0;
	int *vPtr = M>0 ? &V[0] : 0;
	printf("%d\n",calculate(N,M,S,T,uPtr,vPtr));

	return 0;
}

Download