LeetCode Solutions

148. Sort List

Time: $O(n\log n)$

Space: $O(1)$

			

class Solution {
 public:
  ListNode* sortList(ListNode* head) {
    const int length = getLength(head);
    ListNode dummy(0, head);

    for (int k = 1; k < length; k *= 2) {
      ListNode* curr = dummy.next;
      ListNode* tail = &dummy;
      while (curr) {
        ListNode* l = curr;
        ListNode* r = split(l, k);
        curr = split(r, k);
        auto [mergedHead, mergedTail] = merge(l, r);
        tail->next = mergedHead;
        tail = mergedTail;
      }
    }

    return dummy.next;
  }

 private:
  int getLength(ListNode* head) {
    int length = 0;
    for (ListNode* curr = head; curr; curr = curr->next)
      ++length;
    return length;
  }

  ListNode* split(ListNode* head, int k) {
    while (--k && head)
      head = head->next;
    ListNode* rest = head ? head->next : nullptr;
    if (head != nullptr)
      head->next = nullptr;
    return rest;
  }

  pair<ListNode*, ListNode*> merge(ListNode* l1, ListNode* l2) {
    ListNode dummy(0);
    ListNode* tail = &dummy;

    while (l1 && l2) {
      if (l1->val > l2->val)
        swap(l1, l2);
      tail->next = l1;
      l1 = l1->next;
      tail = tail->next;
    }
    tail->next = l1 ? l1 : l2;
    while (tail->next)
      tail = tail->next;

    return {dummy.next, tail};
  }
};
			

/**
 * 148. Sort List -- bottom-up (iterative) merge sort on a singly linked list.
 *
 * <p>The list is viewed as consecutive sorted runs of {@code width} nodes.
 * Every pass merges neighbouring runs pairwise and then doubles {@code width}.
 */
class Solution {
  /**
   * Sorts the list that starts at {@code head} by ascending {@code val}.
   *
   * @param head first node of the list, or {@code null} for an empty list
   * @return first node of the sorted list ({@code null} if the input was empty)
   */
  public ListNode sortList(ListNode head) {
    final int n = getLength(head);
    final ListNode sentinel = new ListNode(0, head);

    for (int width = 1; width < n; width *= 2) {
      ListNode prev = sentinel;         // Last node of the already merged part.
      ListNode pending = sentinel.next; // First node not yet processed.
      while (pending != null) {
        final ListNode left = pending;
        final ListNode right = split(left, width); // Cuts left to one run.
        pending = split(right, width);             // Cuts right to one run.
        final ListNode[] headAndTail = merge(left, right);
        prev.next = headAndTail[0];
        prev = headAndTail[1];
      }
    }

    return sentinel.next;
  }

  /** Counts the nodes reachable from {@code head} by following {@code next}. */
  private int getLength(ListNode head) {
    int count = 0;
    ListNode node = head;
    while (node != null) {
      ++count;
      node = node.next;
    }
    return count;
  }

  /**
   * Detaches the list after its k-th node (or after fewer nodes if the list is
   * shorter) and returns the detached remainder.
   *
   * @return head of the remainder, or {@code null} when nothing is left over
   */
  private ListNode split(ListNode head, int k) {
    // Advance up to k - 1 links so that head ends on the run's last node.
    for (int steps = k - 1; steps > 0 && head != null; --steps)
      head = head.next;

    if (head == null)
      return null;

    final ListNode remainder = head.next;
    head.next = null;
    return remainder;
  }

  /**
   * Merges two sorted, null-terminated lists into one sorted list.
   *
   * @return a two-element array: index 0 is the merged head, index 1 is the
   *     last node of the merged list
   */
  private ListNode[] merge(ListNode l1, ListNode l2) {
    final ListNode sentinel = new ListNode(0);
    ListNode last = sentinel;

    // Keep the smaller front node in l1 by swapping the two cursors, then
    // append l1's front node.
    while (l1 != null && l2 != null) {
      if (l1.val > l2.val) {
        final ListNode smaller = l2;
        l2 = l1;
        l1 = smaller;
      }
      last.next = l1;
      last = l1;
      l1 = l1.next;
    }

    // At most one list still has nodes; attach it and walk to its end so the
    // caller receives the true tail.
    last.next = (l1 != null) ? l1 : l2;
    while (last.next != null)
      last = last.next;

    return new ListNode[] {sentinel.next, last};
  }
}
			

class Solution:
  """148. Sort List -- bottom-up (iterative) merge sort on a linked list."""

  def sortList(self, head: ListNode) -> ListNode:
    """Sorts the list starting at `head` by ascending `val`.

    The list is viewed as consecutive sorted runs of `width` nodes. Every pass
    merges neighbouring runs pairwise and then doubles `width`.

    Args:
      head: First node of the list, or None for an empty list.

    Returns:
      The first node of the sorted list (None if the input was empty).
    """

    def count_nodes(node: ListNode) -> int:
      """Counts the nodes reachable from `node` by following `next`."""
      total = 0
      while node:
        total += 1
        node = node.next
      return total

    def split(head: ListNode, k: int) -> ListNode:
      """Cuts the list after its k-th node and returns the detached rest.

      If the list has fewer than k nodes, nothing is left over and None is
      returned.
      """
      # Advance up to k - 1 links so that `head` ends on the run's last node.
      for _ in range(k - 1):
        if not head:
          break
        head = head.next
      if not head:
        return None
      remainder = head.next
      head.next = None
      return remainder

    def merge(l1: ListNode, l2: ListNode) -> tuple:
      """Merges two sorted lists; returns (merged head, merged last node)."""
      sentinel = ListNode(0)
      last = sentinel

      # Keep the smaller front node in `l1` by swapping the two cursors, then
      # append `l1`'s front node.
      while l1 and l2:
        if l1.val > l2.val:
          l1, l2 = l2, l1
        last.next = l1
        last = l1
        l1 = l1.next

      # At most one list still has nodes; attach it and walk to its end so
      # the caller receives the true tail.
      last.next = l1 or l2
      while last.next:
        last = last.next

      return sentinel.next, last

    n = count_nodes(head)
    sentinel = ListNode(0, head)

    width = 1
    while width < n:
      prev = sentinel          # Last node of the already merged part.
      pending = sentinel.next  # First node not yet processed.
      while pending:
        left = pending
        right = split(left, width)     # Cuts `left` down to one run.
        pending = split(right, width)  # Cuts `right` down to one run.
        merged_head, merged_tail = merge(left, right)
        prev.next = merged_head
        prev = merged_tail
      width *= 2

    return sentinel.next