LeetCode Solutions
148. Sort List
Time: $O(n \log n)$; Space: $O(1)$ (iterative bottom-up merge sort, so there is no recursion stack)
class Solution {
public:
ListNode* sortList(ListNode* head) {
const int length = getLength(head);
ListNode dummy(0, head);
for (int k = 1; k < length; k *= 2) {
ListNode* curr = dummy.next;
ListNode* tail = &dummy;
while (curr) {
ListNode* l = curr;
ListNode* r = split(l, k);
curr = split(r, k);
auto [mergedHead, mergedTail] = merge(l, r);
tail->next = mergedHead;
tail = mergedTail;
}
}
return dummy.next;
}
private:
int getLength(ListNode* head) {
int length = 0;
for (ListNode* curr = head; curr; curr = curr->next)
++length;
return length;
}
ListNode* split(ListNode* head, int k) {
while (--k && head)
head = head->next;
ListNode* rest = head ? head->next : nullptr;
if (head != nullptr)
head->next = nullptr;
return rest;
}
pair<ListNode*, ListNode*> merge(ListNode* l1, ListNode* l2) {
ListNode dummy(0);
ListNode* tail = &dummy;
while (l1 && l2) {
if (l1->val > l2->val)
swap(l1, l2);
tail->next = l1;
l1 = l1->next;
tail = tail->next;
}
tail->next = l1 ? l1 : l2;
while (tail->next)
tail = tail->next;
return {dummy.next, tail};
}
};
class Solution {
  /**
   * Sorts a singly linked list in ascending order of {@code val} using an
   * iterative, bottom-up merge sort.
   *
   * <p>Adjacent runs of width 1, 2, 4, ... are merged pairwise by relinking
   * nodes, so no recursion and no auxiliary collections are needed.
   *
   * @param head first node of the list (may be {@code null})
   * @return first node of the sorted list; the original nodes are reused
   */
  public ListNode sortList(ListNode head) {
    final int n = countNodes(head);
    ListNode sentinel = new ListNode(0, head);
    for (int width = 1; width < n; width <<= 1) {
      ListNode tail = sentinel;
      ListNode remaining = sentinel.next;
      while (remaining != null) {
        // Detach two consecutive runs of up to `width` nodes each.
        ListNode left = remaining;
        ListNode right = detachRun(left, width);
        remaining = detachRun(right, width);
        // Merge them directly behind the current tail and move the tail on.
        tail = appendMerged(tail, left, right);
      }
    }
    return sentinel.next;
  }

  /** Returns the number of nodes reachable from {@code node}. */
  private int countNodes(ListNode node) {
    int n = 0;
    while (node != null) {
      n++;
      node = node.next;
    }
    return n;
  }

  /**
   * Advances to the k-th node of the run starting at {@code head} (or to the
   * list's end if it is shorter), null-terminates the run there, and returns
   * the node that followed it, or {@code null} if there was none.
   * Only called with {@code k >= 1}.
   */
  private ListNode detachRun(ListNode head, int k) {
    for (int i = 1; i < k && head != null; i++) {
      head = head.next;
    }
    if (head == null) {
      return null;
    }
    ListNode following = head.next;
    head.next = null;
    return following;
  }

  /**
   * Merges the sorted, null-terminated runs {@code a} and {@code b} and links
   * the result after {@code tail}.
   *
   * <p>The smaller head is kept in {@code a} by swapping, so ties go to
   * whichever run {@code a} currently holds.
   *
   * @return the last node of the merged run
   */
  private ListNode appendMerged(ListNode tail, ListNode a, ListNode b) {
    while (a != null && b != null) {
      if (b.val < a.val) {
        ListNode swap = a;
        a = b;
        b = swap;
      }
      tail.next = a;
      tail = a;
      a = a.next;
    }
    // At most one run still has nodes; attach it and walk to its end.
    tail.next = (a != null) ? a : b;
    while (tail.next != null) {
      tail = tail.next;
    }
    return tail;
  }
}
class Solution:
    def sortList(self, head: ListNode) -> ListNode:
        """Sort a singly linked list in ascending order of ``val``.

        Uses an iterative, bottom-up merge sort. Adjacent runs of width
        1, 2, 4, ... are merged pairwise by relinking the existing nodes,
        so no recursion is involved.

        Args:
            head: first node of the list (may be None).

        Returns:
            The first node of the sorted list (the original nodes, relinked).
        """

        def cut(node: ListNode, size: int) -> ListNode:
            # Keep at most `size` nodes starting at `node`, terminate that run,
            # and hand back whatever followed it (None if nothing did).
            for _ in range(size - 1):
                if not node:
                    break
                node = node.next
            if not node:
                return None
            following = node.next
            node.next = None
            return following

        def weave(a: ListNode, b: ListNode) -> tuple:
            # Merge two sorted runs and return (first, last) of the result.
            # The smaller head is swapped into `a`, so ties go to whichever
            # run `a` currently holds.
            anchor = ListNode(0)
            last = anchor
            while a and b:
                if a.val > b.val:
                    a, b = b, a
                last.next = a
                last = a
                a = a.next
            last.next = a or b
            while last.next:
                last = last.next
            return anchor.next, last

        total = 0
        walker = head
        while walker:
            total += 1
            walker = walker.next

        anchor = ListNode(0, head)
        width = 1
        while width < total:
            tail = anchor
            pending = anchor.next
            while pending:
                left = pending
                right = cut(left, width)
                pending = cut(right, width)
                first, last = weave(left, right)
                tail.next = first
                tail = last
            width <<= 1
        return anchor.next