2020 Jan Silver Problem 3 Wormhole Sort

From Wiki
Jump to navigation Jump to search

Official Problem Statement

Wormhole Sort

Problem Statement

In the USACO problem "Wormhole Sort" (problem ID 992), N cows stand at positions 1 through N. The cow at position i must end up at position p_i, where p_1, …, p_N is a permutation of 1, …, N. There are M wormholes: wormhole i connects positions a_i and b_i and has width w_i. Two cows standing at the two ends of a wormhole can swap places through it, any number of times. Farmer John wants to sort the cows while making the narrowest wormhole he has to use as wide as possible.

Your task is to output the largest possible width of the narrowest wormhole used to sort the cows. If the cows are already sorted, so that no wormhole is needed, output -1. (The official problem guarantees that the cows can always be sorted.)

Solution

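Cows can be rearranged arbitrarily within a connected component of the wormhole graph. So the cows can be sorted using only wormholes of width at least W exactly when, in the graph of those wormholes, every position i lies in the same connected component as p_i. Lowering W only adds wormholes, so if some W works, every smaller W also works; the answer is the largest W that works. It can be found in either of two ways:

1. Binary search on W, using a Union-Find (disjoint-set) structure to check connectivity for each candidate (the C++ code below).
2. Greedily, Kruskal-style: add wormholes from widest to narrowest until every position i is connected to p_i. The width of the last wormhole added is the answer (the Java and Python code below).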

Code examples

C++

#include <bits/stdc++.h>
using namespace std;

// Problem size: N cows / positions and M wormholes.
int N;
int M;
// cows[i] = (value of the cow read at position i, i); main() sorts this by value.
vector<pair<int, int>> cows;
// Each wormhole stored as (width, a, b) with 0-indexed endpoints a and b.
vector<tuple<int, int, int>> wormholes;
// Union-Find parent links, rebuilt by can_sort() for every candidate width.
vector<int> parent;

// Returns the root of the Union-Find set containing x.
// Iterative two-pass path compression: merge() links roots without union by
// rank, so parent chains can grow to about N long, and a recursive find would
// recurse that deep and risk overflowing the call stack.
int find(int x) {
    int root = x;
    while (parent[root] != root) {
        root = parent[root];
    }
    // Second pass: point every node on the path directly at the root.
    while (parent[x] != root) {
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}

// Joins the sets containing a and b by linking a's root under b's root.
// Returns true if a link was added, false if a and b were already joined.
bool merge(int a, int b) {
    const int rootA = find(a);
    const int rootB = find(b);
    if (rootA != rootB) {
        parent[rootA] = rootB;
        return true;
    }
    return false;
}

bool can_sort(int min_width) {
    parent.resize(N);
    iota(parent.begin(), parent.end(), 0);

    for (auto &[width, a, b] : wormholes) {
        if (width >= min_width) {
            merge(a, b);
        }
    }

    for (int i = 0; i < N; i++) {
        if (find(i) != find(cows[i].second)) {
            return false;
        }
    }

    return true;
}

// Reads N, M, the N cow values, and M wormholes "a b width" (1-indexed endpoints)
// from stdin. Prints the largest width W such that the wormholes of width >= W
// suffice to sort the cows. Prints -1 if the cows are already sorted (no wormhole
// needed) or cannot be sorted at all.
int main() {
    cin >> N >> M;
    cows.resize(N);
    wormholes.resize(M);

    for (int i = 0; i < N; i++) {
        cin >> cows[i].first;
        cows[i].second = i;
    }

    for (int i = 0; i < M; i++) {
        int a, b, width;
        cin >> a >> b >> width;
        wormholes[i] = make_tuple(width, a - 1, b - 1);
    }

    // After sorting by value, cows[i].second is where the cow that belongs at i starts.
    sort(cows.begin(), cows.end());

    // Assumes widths lie in [1, 1e9] (the problem's constraints). A threshold of 1
    // then keeps every wormhole; if even that cannot sort the cows, nothing can.
    if (!can_sort(1)) {
        cout << -1 << endl;
        return 0;
    }

    // can_sort is monotone: raising min_width only removes wormholes, so we need
    // the LARGEST threshold that still works (the original searched for the
    // smallest, which is always 1). NO_WORMHOLES exceeds every width, keeps no
    // wormholes, and therefore succeeds only when the cows are already sorted.
    const int NO_WORMHOLES = 1000000001;
    int low = 1, high = NO_WORMHOLES;  // invariant: can_sort(low) is true
    while (low < high) {
        int mid = low + (high - low + 1) / 2;  // round up so `low = mid` always progresses
        if (can_sort(mid)) {
            low = mid;
        } else {
            high = mid - 1;
        }
    }

    if (low == NO_WORMHOLES) {
        cout << -1 << endl;  // already sorted: no wormhole needed
    } else {
        cout << low << endl;
    }

    return 0;
}


Java

import java.util.*;
import java.io.*;

/**
 * Greedy solution to USACO 2020 January Silver, Problem 3 ("Wormhole Sort").
 *
 * <p>Input ({@code wormsort.in}):
 *
 * <ul>
 *   <li>N and M;
 *   <li>N values p_1..p_N, where the cow at position i must reach position p_i;
 *   <li>M lines "a b w", each a wormhole of width w between positions a and b.
 * </ul>
 *
 * <p>All positions are 1-indexed. Output ({@code wormsort.out}): the width of the narrowest
 * wormhole the greedy uses, or -1 when no wormhole is needed.
 */
public class WormSort {
    static int N, M;
    static List<Pair> cows;
    static List<Tuple> wormholes;
    /** Union-Find parent links; {@code parent[x] == x} marks a root. */
    static int[] parent;

    /**
     * Returns the root of the set containing {@code x}.
     *
     * <p>Iterative with path halving. {@link #merge} links roots without union by rank, so parent
     * chains can grow to about N long; a recursive find could then overflow the call stack.
     */
    public static int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]]; // skip a level to shorten later lookups
            x = parent[x];
        }
        return x;
    }

    /** Joins the sets containing {@code a} and {@code b} by linking a's root under b's root. */
    public static void merge(int a, int b) {
        parent[find(a)] = find(b);
    }

    /**
     * Adds wormholes from widest to narrowest until every cow's current position is connected to
     * its target position, and returns the width of the last wormhole added.
     *
     * <p>Widths are consumed in non-increasing order, so the last one added is the narrowest
     * used. Stopping as soon as everything is connected makes that width as large as possible.
     *
     * @param n number of positions (0-indexed, 0..n-1)
     * @param cowList pairs (target position, current position), both 0-indexed
     * @param holeList wormholes as (width, a, b) with 0-indexed endpoints; not modified
     * @return the largest achievable minimum width; -1 if no wormhole is needed, or if the
     *     wormholes run out before every cow is connected to its target
     */
    static int solve(int n, List<Pair> cowList, List<Tuple> holeList) {
        List<Tuple> byWidth = new ArrayList<>(holeList);
        Collections.sort(byWidth); // ascending; consumed from the end, i.e. widest first
        parent = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }

        int index = byWidth.size();
        int width = -1;
        for (Pair cow : cowList) {
            while (find(cow.first) != find(cow.second)) {
                index--;
                if (index < 0) {
                    return -1; // out of wormholes: these cows cannot be sorted
                }
                Tuple x = byWidth.get(index);
                width = x.width;
                merge(x.a, x.b);
            }
        }
        return width;
    }

    /** Reads {@code wormsort.in}, solves, and writes the answer to {@code wormsort.out}. */
    public static void main(String[] args) throws IOException {
        try (BufferedReader in = new BufferedReader(new FileReader("wormsort.in"));
                PrintWriter out = new PrintWriter(new FileWriter("wormsort.out"))) {
            StringTokenizer st = new StringTokenizer(in.readLine());
            N = Integer.parseInt(st.nextToken());
            M = Integer.parseInt(st.nextToken());
            cows = new ArrayList<>(N);
            wormholes = new ArrayList<>(M);

            st = new StringTokenizer(in.readLine());
            for (int i = 0; i < N; i++) {
                int cowPos = Integer.parseInt(st.nextToken()) - 1;
                cows.add(new Pair(cowPos, i));
            }

            for (int i = 0; i < M; i++) {
                st = new StringTokenizer(in.readLine());
                int a = Integer.parseInt(st.nextToken()) - 1;
                int b = Integer.parseInt(st.nextToken()) - 1;
                int width = Integer.parseInt(st.nextToken());
                wormholes.add(new Tuple(width, a, b));
            }

            out.println(solve(N, cows, wormholes));
        }
    }

    /** Two ints; used here as (target position, current position) of one cow. */
    static class Pair {
        int first, second;

        public Pair(int first, int second) {
            this.first = first;
            this.second = second;
        }
    }

    /** A wormhole (width, a, b); ordered by width only. */
    static class Tuple implements Comparable<Tuple> {
        int width, a, b;

        public Tuple(int width, int a, int b) {
            this.width = width;
            this.a = a;
            this.b = b;
        }

        @Override
        public int compareTo(Tuple other) {
            return Integer.compare(this.width, other.width);
        }
    }
}

Python

from sys import stdin
from operator import attrgetter

class Pair:
    """Plain two-field record with attributes ``first`` and ``second``."""

    def __init__(self, first, second):
        self.first, self.second = first, second

class Tuple:
    """Wormhole record: ``width`` plus the two endpoints ``a`` and ``b``."""

    def __init__(self, width, a, b):
        self.width, self.a, self.b = width, a, b

def find(x):
    """Return the root of ``x``'s set in the module-level ``parent`` list.

    Iterative with path halving. ``merge`` links roots without union by rank,
    so parent chains can become long. The recursive version would then exceed
    Python's default recursion limit (about 1000 frames) and raise
    RecursionError.
    """
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # skip a level to shorten later lookups
        x = parent[x]
    return x

def merge(a, b):
    """Union the sets of ``a`` and ``b``, linking ``a``'s root under ``b``'s root."""
    # Same evaluation order as `parent[find(a)] = find(b)`: Python evaluates
    # the right-hand side (find(b)) before the subscript target (find(a)).
    root_b = find(b)
    root_a = find(a)
    parent[root_a] = root_b

with open('wormsort.in', 'r') as fin, open('wormsort.out', 'w') as fout:
    # Parse all whitespace-separated tokens at once, so blank or trailing
    # lines cannot break the parse.
    data = fin.read().split()
    N, M = int(data[0]), int(data[1])

    # cows[i].first is the 0-indexed target position of the cow now at
    # position i; cows[i].second is i itself.
    cows = [Pair(int(data[2 + i]) - 1, i) for i in range(N)]

    # Each wormhole line is "a b w" with 1-indexed endpoints, but Tuple takes
    # (width, a, b). Reorder the fields and shift the endpoints to 0-indexed.
    base = 2 + N
    wormholes = []
    for k in range(M):
        start = base + 3 * k
        a, b, w = map(int, data[start:start + 3])
        wormholes.append(Tuple(w, a - 1, b - 1))

    # Ascending by width. `index` counts down from the end, so the widest
    # wormholes are tried first and the last one added is the narrowest used.
    wormholes.sort(key=attrgetter('width'))
    parent = list(range(N))

    index = M
    width = -1
    for cow in cows:
        while find(cow.first) != find(cow.second):
            index -= 1
            if index < 0:
                # Out of wormholes: these cows cannot be sorted.
                width = -1
                break
            x = wormholes[index]
            width = x.width
            merge(x.a, x.b)
        if index < 0:
            break

    fout.write(str(width) + '\n')