From 6643673f28b9273149fc945849a13ed832e9ef33 Mon Sep 17 00:00:00 2001
From: bunnei <bunneidev@gmail.com>
Date: Sun, 18 Jan 2015 01:27:46 -0500
Subject: [PATCH] WaitSynchronizationN: Refactor to fix several bugs

- Separate wait checking from waiting the current thread
- Resume thread when wait_all=true only if all objects are available at once
- Set output to correct wait object index when there are duplicate handles
---
 src/core/hle/kernel/event.cpp     |  6 +--
 src/core/hle/kernel/kernel.h      | 12 +++---
 src/core/hle/kernel/mutex.cpp     |  8 ++--
 src/core/hle/kernel/semaphore.cpp |  6 +--
 src/core/hle/kernel/thread.cpp    | 61 ++++++++++++++-----------------
 src/core/hle/kernel/thread.h      |  4 +-
 src/core/hle/kernel/timer.cpp     |  6 +--
 src/core/hle/svc.cpp              | 50 +++++++++++++------------
 8 files changed, 75 insertions(+), 78 deletions(-)

diff --git a/src/core/hle/kernel/event.cpp b/src/core/hle/kernel/event.cpp
index 41e1bd6c5..ae9b06b84 100644
--- a/src/core/hle/kernel/event.cpp
+++ b/src/core/hle/kernel/event.cpp
@@ -28,11 +28,11 @@ public:
     bool signaled;                          ///< Whether the event has already been signaled
     std::string name;                       ///< Name of event (optional)
 
-    ResultVal<bool> Wait(unsigned index) override {
+    ResultVal<bool> Wait(bool wait_thread) override {
         bool wait = !signaled;
-        if (wait) {
+        if (wait && wait_thread) {
             AddWaitingThread(GetCurrentThread());
-            Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_EVENT, this, index);
+            Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_EVENT, this);
         }
         return MakeResult<bool>(wait);
     }
diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h
index d98fd0389..cfaf0c901 100644
--- a/src/core/hle/kernel/kernel.h
+++ b/src/core/hle/kernel/kernel.h
@@ -65,18 +65,18 @@ public:
     virtual Kernel::HandleType GetHandleType() const = 0;
 
     /**
-     * Wait the current thread for kernel object to synchronize.
-     * @param index Index of wait object (only applies to WaitSynchronizationN)
-     * @return True if the current thread should wait as a result of the wait
+     * Check if this object is available, (optionally) wait the current thread if not
+     * @param wait_thread If true, wait the current thread if this object is unavailable
+     * @return True if the current thread should wait due to this object being unavailable
      */
-    virtual ResultVal<bool> Wait(unsigned index = 0) {
+    virtual ResultVal<bool> Wait(bool wait_thread) {
         LOG_ERROR(Kernel, "(UNIMPLEMENTED)");
         return UnimplementedFunction(ErrorModule::Kernel);
     }
 
     /**
-     * Acquire/lock the kernel object if it is available
-     * @return True if we were able to acquire the kernel object, otherwise false
+     * Acquire/lock the this object if it is available
+     * @return True if we were able to acquire this object, otherwise false
      */
     virtual ResultVal<bool> Acquire() {
         LOG_ERROR(Kernel, "(UNIMPLEMENTED)");
diff --git a/src/core/hle/kernel/mutex.cpp b/src/core/hle/kernel/mutex.cpp
index 37e7be4e7..f97c69a78 100644
--- a/src/core/hle/kernel/mutex.cpp
+++ b/src/core/hle/kernel/mutex.cpp
@@ -26,7 +26,7 @@ public:
     Handle lock_thread;                         ///< Handle to thread that currently has mutex
     std::string name;                           ///< Name of mutex (optional)
 
-    ResultVal<bool> Wait(unsigned index) override;
+    ResultVal<bool> Wait(bool wait_thread) override;
     ResultVal<bool> Acquire() override;
 };
 
@@ -156,10 +156,10 @@ Handle CreateMutex(bool initial_locked, const std::string& name) {
     return handle;
 }
 
-ResultVal<bool> Mutex::Wait(unsigned index) {
-    if (locked) {
+ResultVal<bool> Mutex::Wait(bool wait_thread) {
+    if (locked && wait_thread) {
         AddWaitingThread(GetCurrentThread());
-        Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_MUTEX, this, index);
+        Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_MUTEX, this);
     }
 
     return MakeResult<bool>(locked);
diff --git a/src/core/hle/kernel/semaphore.cpp b/src/core/hle/kernel/semaphore.cpp
index 6464b2580..42b5cf704 100644
--- a/src/core/hle/kernel/semaphore.cpp
+++ b/src/core/hle/kernel/semaphore.cpp
@@ -32,11 +32,11 @@ public:
         return available_count > 0;
     }
 
-    ResultVal<bool> Wait(unsigned index) override {
+    ResultVal<bool> Wait(bool wait_thread) override {
         bool wait = !IsAvailable();
 
-        if (wait) {
-            Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_SEMA, this, index);
+        if (wait && wait_thread) {
+            Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_SEMA, this);
             AddWaitingThread(GetCurrentThread());
         }
 
diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp
index 6b0bdebb5..62b85f56a 100644
--- a/src/core/hle/kernel/thread.cpp
+++ b/src/core/hle/kernel/thread.cpp
@@ -22,11 +22,11 @@
 
 namespace Kernel {
 
-ResultVal<bool> Thread::Wait(unsigned index) {
+ResultVal<bool> Thread::Wait(bool wait_thread) {
     const bool wait = status != THREADSTATUS_DORMANT;
-    if (wait) {
+    if (wait && wait_thread) {
         AddWaitingThread(GetCurrentThread());
-        WaitCurrentThread_WaitSynchronization(WAITTYPE_THREADEND, this, index);
+        WaitCurrentThread_WaitSynchronization(WAITTYPE_THREADEND, this);
     }
 
     return MakeResult<bool>(wait);
@@ -97,7 +97,7 @@ static bool CheckWaitType(const Thread* thread, WaitType type) {
 /// Check if a thread is blocking on a specified wait type with a specified handle
 static bool CheckWaitType(const Thread* thread, WaitType type, Object* wait_object) {
     for (auto itr = thread->wait_objects.begin(); itr != thread->wait_objects.end(); ++itr) {
-        if (itr->first == wait_object)
+        if (*itr == wait_object)
             return CheckWaitType(thread, type);
     }
     return false;
@@ -234,16 +234,7 @@ void WaitCurrentThread_WaitSynchronization(WaitType wait_type, WaitObject* wait_
     Thread* thread = GetCurrentThread();
     thread->wait_type = wait_type;
 
-    bool insert_wait_object = true;
-    for (auto itr = thread->wait_objects.begin(); itr < thread->wait_objects.end(); ++itr) {
-        if (itr->first == wait_object) {
-            insert_wait_object = false;
-            break;
-        }
-    }
-
-    if (insert_wait_object)
-        thread->wait_objects.push_back(std::pair<SharedPtr<WaitObject>, unsigned>(wait_object, index));
+    thread->wait_objects.push_back(wait_object);
 
     ChangeThreadState(thread, ThreadStatus(THREADSTATUS_WAIT | (thread->status & THREADSTATUS_SUSPEND)));
 }
@@ -288,31 +279,35 @@ void Thread::ReleaseFromWait(WaitObject* wait_object) {
         return;
     }
 
-    // Remove this thread from the wait_object
+    // Remove this thread from the waiting object's thread list
     wait_object->RemoveWaitingThread(this);
 
-    // Find the waiting object
-    auto itr = wait_objects.begin();
-    for (; itr != wait_objects.end(); ++itr) {
-        if (wait_object == itr->first)
-            break;
+    unsigned index = 0;
+    bool wait_all_failed = false; // Will be set to true if any object is unavailable
+
+    // Iterate through all waiting objects to check availability...
+    for (auto itr = wait_objects.begin(); itr != wait_objects.end(); ++itr) {
+        auto res = (*itr)->Wait(false);
+
+        if (*res && res.Succeeded())
+            wait_all_failed = true;
+
+        // The output should be the last index of wait_object
+        if (*itr == wait_object)
+            index = itr - wait_objects.begin();
     }
-    unsigned index = itr->second;
 
-    // Remove the wait_object from this thread
-    if (itr != wait_objects.end())
-        wait_objects.erase(itr);
-
-    // If wait_all=false, resume the thread on a release wait_object from wait
-    if (!wait_all) {
-        SetReturnValue(RESULT_SUCCESS, index);
-        ResumeFromWait();
-    } else {
-        // Otherwise, wait_all=true, only resume the thread if all wait_object's have been released
-        if (wait_objects.empty()) {
+    // If we are waiting on all objects...
+    if (wait_all) {
+        // Resume the thread only if all are available...
+        if (!wait_all_failed) {
             SetReturnValue(RESULT_SUCCESS, -1);
             ResumeFromWait();
         }
+    } else {
+        // Otherwise, resume
+        SetReturnValue(RESULT_SUCCESS, index);
+        ResumeFromWait();
     }
 }
 
@@ -324,7 +319,7 @@ void Thread::ResumeFromWait() {
 
     // Remove this thread from all other WaitObjects
     for (auto wait_object : wait_objects)
-        wait_object.first->RemoveWaitingThread(this);
+        wait_object->RemoveWaitingThread(this);
 
     wait_objects.clear();
 
diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h
index 9faf89c15..dff6bbaec 100644
--- a/src/core/hle/kernel/thread.h
+++ b/src/core/hle/kernel/thread.h
@@ -70,7 +70,7 @@ public:
     inline bool IsSuspended() const { return (status & THREADSTATUS_SUSPEND) != 0; }
     inline bool IsIdle() const { return idle; }
 
-    ResultVal<bool> Wait(unsigned index) override;
+    ResultVal<bool> Wait(bool wait_thread) override;
     ResultVal<bool> Acquire() override;
 
     s32 GetPriority() const { return current_priority; }
@@ -117,7 +117,7 @@ public:
     s32 processor_id;
 
     WaitType wait_type;
-    std::vector<std::pair<SharedPtr<WaitObject>, unsigned>> wait_objects;
+    std::vector<SharedPtr<WaitObject>> wait_objects;
     VAddr wait_address;
 
     std::string name;
diff --git a/src/core/hle/kernel/timer.cpp b/src/core/hle/kernel/timer.cpp
index 6497bb349..9f0fbafe2 100644
--- a/src/core/hle/kernel/timer.cpp
+++ b/src/core/hle/kernel/timer.cpp
@@ -29,11 +29,11 @@ public:
     u64 initial_delay;                      ///< The delay until the timer fires for the first time
     u64 interval_delay;                     ///< The delay until the timer fires after the first time
 
-    ResultVal<bool> Wait(unsigned index) override {
+    ResultVal<bool> Wait(bool wait_thread) override {
         bool wait = !signaled;
-        if (wait) {
+        if (wait && wait_thread) {
             AddWaitingThread(GetCurrentThread());
-            Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_TIMER, this, index);
+            Kernel::WaitCurrentThread_WaitSynchronization(WAITTYPE_TIMER, this);
         }
         return MakeResult<bool>(wait);
     }
diff --git a/src/core/hle/svc.cpp b/src/core/hle/svc.cpp
index a27aa6269..059451100 100644
--- a/src/core/hle/svc.cpp
+++ b/src/core/hle/svc.cpp
@@ -127,7 +127,7 @@ static Result WaitSynchronization1(Handle handle, s64 nano_seconds) {
     LOG_TRACE(Kernel_SVC, "called handle=0x%08X(%s:%s), nanoseconds=%lld", handle,
             object->GetTypeName().c_str(), object->GetName().c_str(), nano_seconds);
 
-    ResultVal<bool> wait = object->Wait();
+    ResultVal<bool> wait = object->Wait(true);
 
     // Check for next thread to schedule
     if (wait.Succeeded() && *wait) {
@@ -146,8 +146,7 @@ static Result WaitSynchronization1(Handle handle, s64 nano_seconds) {
 
 /// Wait for the given handles to synchronize, timeout after the specified nanoseconds
 static Result WaitSynchronizationN(s32* out, Handle* handles, s32 handle_count, bool wait_all, s64 nano_seconds) {
-    bool wait_thread = false;
-    bool wait_all_succeeded = false;
+    bool wait_thread = !wait_all;
     int handle_index = 0;
 
     // Handles pointer is invalid
@@ -158,40 +157,43 @@ static Result WaitSynchronizationN(s32* out, Handle* handles, s32 handle_count,
     if (handle_count < 0)
         return ResultCode(ErrorDescription::OutOfRange, ErrorModule::OS, ErrorSummary::InvalidArgument, ErrorLevel::Usage).raw;
 
-    // If handle_count is non-zero, iterate through them and wait/acquire the objects as needed
+    // If handle_count is non-zero, iterate through them and wait the current thread on the objects
     if (handle_count != 0) {
-        while (handle_index < handle_count) {
-            SharedPtr<Kernel::Object> object = Kernel::g_handle_table.GetGeneric(handles[handle_index]);
+        bool selected = false; // True once an object has been selected
+        for (int i = 0; i < handle_count; ++i) {
+            SharedPtr<Kernel::Object> object = Kernel::g_handle_table.GetGeneric(handles[i]);
             if (object == nullptr)
                 return InvalidHandle(ErrorModule::Kernel).raw;
 
-            ResultVal<bool> wait = object->Wait(handle_index);
+            ResultVal<bool> wait = object->Wait(true);
 
-            wait_thread = (wait.Succeeded() && *wait);
-
-            // If this object waited and we are waiting on all objects to synchronize
-            if (wait_thread && wait_all)
-                // Enforce later on that this thread does not continue
-                wait_all_succeeded = true;
-
-            // If this object synchronized and we are not waiting on all objects to synchronize
-            if (!wait_thread && !wait_all)
-                // We're done, the thread will continue
-                break;
-
-            handle_index++;
+            // Check if the current thread should wait on the object...
+            if (wait.Succeeded() && *wait) {
+                // Check we are waiting on all objects...
+                if (wait_all)
+                    // Wait the thread
+                    wait_thread = true;
+            } else {
+                // Do not wait on this object, check if this object should be selected...
+                if (!wait_all && !selected) {
+                    // Do not wait the thread
+                    wait_thread = false;
+                    handle_index = i;
+                    selected = true;
+                }
+            }
         }
     } else {
         // If no handles were passed in, put the thread to sleep only when wait_all=false
-        // NOTE: This is supposed to deadlock if no timeout was specified
+        // NOTE: This is supposed to deadlock the current thread if no timeout was specified
         if (!wait_all) {
             wait_thread = true;
             Kernel::WaitCurrentThread(WAITTYPE_SLEEP);
         }
     }
 
-    // Change the thread state to waiting if blocking on all handles...
-    if (wait_thread || wait_all_succeeded) {
+    // If thread should block, then set its state to waiting and then reschedule...
+    if (wait_thread) {
         // Create an event to wake the thread up after the specified nanosecond delay has passed
         Kernel::WakeThreadAfterDelay(Kernel::GetCurrentThread(), nano_seconds);
         Kernel::GetCurrentThread()->SetWaitAll(wait_all);
@@ -199,7 +201,7 @@ static Result WaitSynchronizationN(s32* out, Handle* handles, s32 handle_count,
         HLE::Reschedule(__func__);
 
         // NOTE: output of this SVC will be set later depending on how the thread resumes
-        return RESULT_DUMMY.raw;
+        return 0xDEADBEEF;
     }
 
     // Acquire objects if we did not wait...