ApiAuthenticationMiddleware performance improvements

Previously we've used one semaphore per all ongoing authentication attempts, which is suboptimal given the existence of a lot of consumers, including ongoing (D)DoS or distributed bruteforce attack. ASF should be as resistant to that as possible, therefore it makes sense to replace the global semaphore with per-IP semaphore (actually task), that can control the access just as well, without stopping other consumers from accessing the same authentication process concurrently.
This commit is contained in:
Archi 2021-08-24 01:37:14 +02:00
parent 47855ca705
commit 69e2a3590c
No known key found for this signature in database
GPG key ID: 6B138B4C64555AEA
2 changed files with 31 additions and 14 deletions

View file

@ -19,6 +19,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if NETFRAMEWORK
using JustArchiNET.Madness;
#endif
using System;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
@ -45,7 +48,7 @@ namespace ArchiSteamFarm.IPC.Integration {
private const byte FailedAuthorizationsCooldownInHours = 1;
private const byte MaxFailedAuthorizationAttempts = 5;
private static readonly SemaphoreSlim AuthorizationSemaphore = new(1, 1);
private static readonly ConcurrentDictionary<IPAddress, Task> AuthorizationTasks = new();
private static readonly Timer ClearFailedAuthorizationsTimer = new(ClearFailedAuthorizations);
private static readonly ConcurrentDictionary<IPAddress, byte> FailedAuthorizations = new();
@ -150,23 +153,37 @@ namespace ArchiSteamFarm.IPC.Integration {
bool authorized = ipcPassword == inputHash;
await AuthorizationSemaphore.WaitAsync().ConfigureAwait(false);
while (true) {
if (AuthorizationTasks.TryGetValue(clientIP, out Task? task)) {
await task.ConfigureAwait(false);
try {
bool hasFailedAuthorizations = FailedAuthorizations.TryGetValue(clientIP, out attempts);
if (hasFailedAuthorizations && (attempts >= MaxFailedAuthorizationAttempts)) {
return (HttpStatusCode.Forbidden, false);
continue;
}
if (!authorized) {
FailedAuthorizations[clientIP] = hasFailedAuthorizations ? ++attempts : (byte) 1;
TaskCompletionSource taskCompletionSource = new();
if (!AuthorizationTasks.TryAdd(clientIP, taskCompletionSource.Task)) {
continue;
}
} finally {
AuthorizationSemaphore.Release();
try {
bool hasFailedAuthorizations = FailedAuthorizations.TryGetValue(clientIP, out attempts);
if (hasFailedAuthorizations && (attempts >= MaxFailedAuthorizationAttempts)) {
return (HttpStatusCode.Forbidden, false);
}
if (!authorized) {
FailedAuthorizations[clientIP] = hasFailedAuthorizations ? ++attempts : (byte) 1;
}
} finally {
AuthorizationTasks.TryRemove(clientIP, out _);
taskCompletionSource.SetResult();
}
return (authorized ? HttpStatusCode.OK : HttpStatusCode.Unauthorized, true);
}
return (authorized ? HttpStatusCode.OK : HttpStatusCode.Unauthorized, true);
}
}
}

View file

@ -29,7 +29,7 @@
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net48'">
<PackageVersion Include="JustArchiNET.Madness" Version="1.2.0" />
<PackageVersion Include="JustArchiNET.Madness" Version="1.3.0" />
<PackageVersion Include="Microsoft.AspNetCore.Cors" Version="2.2.0" />
<PackageVersion Include="Microsoft.AspNetCore.Diagnostics" Version="2.2.0" />
<PackageVersion Include="Microsoft.AspNetCore.HttpOverrides" Version="2.2.0" />