LeetCode 23. 合并 K 个升序链表

给你一个链表数组,每个链表都已经按升序排列。

请你将所有链表合并到一个升序链表中,返回合并后的链表。

示例 :输入:lists = [[1,4,5],[1,3,4],[2,6]] 输出:[1,1,2,3,4,4,5,6] 解释:链表数组如下: [ 1->4->5, 1->3->4, 2->6 ] 将它们合并到一个有序链表中得到。 1->1->2->3->4->4->5->6

C++
#include <iostream>
#include <vector>
#include <algorithm>
#include <limits>

template<typename T>
class LinkedListNode {
public:
    typedef std::shared_ptr<LinkedListNode<T>> NodePtr;
    T value;
    typename NodePtr next;

    LinkedListNode(T value) {
        this->value = value;
        next = nullptr;
    }
};

template<typename T>
class LinkedList {
public:
    typename LinkedListNode<T>::NodePtr head;
	std::weak_ptr<LinkedListNode<T>> tail;
    LinkedList() {
        head = nullptr;
		tail = std::weak_ptr<LinkedListNode<T>>();
    }
    void add(T value) {
        if (head == nullptr) {
			head = std::make_shared<LinkedListNode<T>>(value);
			tail = head;
		}
        else {
            insertAfter((tail.lock()), value);
        }
    }
    void insertAfter(typename LinkedListNode<T>::NodePtr curNode, T value) {
        auto node = std::make_shared<LinkedListNode<T>>(value);
        node->next = curNode->next;
        // if current is tail
        if (curNode->next == nullptr) {
            tail = node;
        }
        curNode->next = node;
    }
    void concat(typename LinkedListNode<T>::NodePtr node) {
        if (tail.lock() == nullptr) {
			head = node;
			tail = node;
		}
        else {
			tail.lock()->next = node;
            tail = node;
		}
    }
    void print() {
        std::cout << "[";
        auto it = head;
        while (it != tail.lock())
        {
            std::cout << it->value;
            std::cout << "->";
            it = it->next;
        }
        if (it != nullptr) {
			std::cout << it->value;
		}
        else {
            std::cout << 'END' ;
		
        }

        std::cout << "]";

    }
};

template<bool BREAK_INPUT>
LinkedList<int> FindInCommon(std::vector<LinkedList<int>> all_inputs) {
    auto rst = LinkedList<int>();
    std::vector<std::shared_ptr<LinkedListNode<int>>> input_its;
    for (auto it : all_inputs) {
		input_its.push_back(it.head);
	}
    while (true)
    {
        int min = std::numeric_limits<int>::max();
        int min_id = -1;
        for (int i = 0; i != input_its.size(); i++) {
            if (input_its[i] != nullptr && input_its[i]->value < min) {
                min = input_its[i]->value;
                min_id = i;
            }
        }
        if (min_id == -1) {
            return rst;
        }
        else {
            if constexpr (BREAK_INPUT) {
                rst.concat(input_its[min_id]);
            }
            else {
                rst.add(input_its[min_id]->value);
            }
            input_its[min_id] = input_its[min_id]->next;
        }
    }
}

const char st = '[';
const char ed = ']';
const char split = ',';
std::vector<LinkedList<int>> parseInput() {
    int depth = 0;
    char token = '_';
    std::vector<LinkedList<int>> all_inputs;

    while (true)
    {
        char next = std::cin.peek();
        if (next == split) {
            std::cin >> token;
            continue;
        }
        if (next == st) {
            if (depth == 1) {
                // make new LinkedList when find '[' at depth 1
                all_inputs.emplace_back(LinkedList<int>());
            }
            depth++;
            std::cin >> token;
            continue;
        }
        if (next == ed) {
            depth--;
            std::cin >> token;
            continue;
        }

        // assume next is dight

        if (depth == 2)
        {
            int value;
            std::cin >> value;
            all_inputs[all_inputs.size() - 1].add(value);
        }
        if (depth == 0) break;
        if (depth == 1) break; // Error
    }

    for (auto list : all_inputs)
    {
        list.print();
        std::cout << std::endl;

    }
    return all_inputs;
}

int main()
{
    auto all_inputs = parseInput();
    auto rst = FindInCommon<false>(all_inputs);
    std::cout << "Rst 0:" << std::endl << '\t';
    rst.print();
    std::cout << std::endl;

    rst = FindInCommon<true>(all_inputs);
    std::cout << "Rst 1:" << std::endl << '\t';
    rst.print();
    std::cout << std::endl;
}

/* 
Inputs: [[1,333,444],[3243,34334,99891],[3,4,4,6,7]]
Outputs:
[1->333->444]
[3243->34334->99891]
[3->4->4->6->7]
Rst 0:
        [1->3->4->4->6->7->333->444->3243->34334->99891]
Rst 1:
        [1->3->4->4->6->7->333->444->3243->34334->99891]
*/

面试的时候出了这道题,想看看我自己还会不会做,要花多久。

结果应该需要30分钟以上,而且C++好多语法都记不住了,还是得练练。

如果不需要设计数据结构和解析输入参数的话,就花不了太多时间了。

原则上,复用原始链表的节点可以避免创建额外的内存。

我这里面用了一大堆智能指针,本身的额外空间也不少了,对于int形的输入来说完全没必要。如果T是更复杂的数据结构应该会有点意义。

发表回复

电子邮件地址不会被公开。必填项已用 * 标注