Improve download cancellation.

This commit is contained in:
Michael Bucari-Tovo 2022-12-18 21:52:51 -07:00
parent 7fd002d2c9
commit 2024d5e116

View File

@ -39,13 +39,14 @@ namespace AaxDecrypter
public long ContentLength { get; private set; } public long ContentLength { get; private set; }
[JsonIgnore] [JsonIgnore]
public bool IsCancelled { get; private set; } public bool IsCancelled => _cancellationSource.IsCancellationRequested;
#endregion #endregion
#region Private Properties #region Private Properties
private FileStream _writeFile { get; } private FileStream _writeFile { get; }
private FileStream _readFile { get; } private FileStream _readFile { get; }
private CancellationTokenSource _cancellationSource { get; } = new();
private EventWaitHandle _downloadedPiece { get; set; } private EventWaitHandle _downloadedPiece { get; set; }
private Task _backgroundDownloadTask { get; set; } private Task _backgroundDownloadTask { get; set; }
@ -126,7 +127,6 @@ namespace AaxDecrypter
RequestHeaders["Range"] = $"bytes={WritePosition}-"; RequestHeaders["Range"] = $"bytes={WritePosition}-";
} }
/// <summary> Begins downloading <see cref="Uri"/> to <see cref="SaveFilePath"/> in a background thread. </summary> /// <summary> Begins downloading <see cref="Uri"/> to <see cref="SaveFilePath"/> in a background thread. </summary>
/// <returns>The downloader <see cref="Task"/></returns> /// <returns>The downloader <see cref="Task"/></returns>
private Task BeginDownloading() private Task BeginDownloading()
@ -137,13 +137,12 @@ namespace AaxDecrypter
if (ContentLength != 0 && WritePosition > ContentLength) if (ContentLength != 0 && WritePosition > ContentLength)
throw new WebException($"Specified write position (0x{WritePosition:X10}) is larger than {nameof(ContentLength)} (0x{ContentLength:X10})."); throw new WebException($"Specified write position (0x{WritePosition:X10}) is larger than {nameof(ContentLength)} (0x{ContentLength:X10}).");
var request = new HttpRequestMessage(HttpMethod.Get, Uri); var request = new HttpRequestMessage(HttpMethod.Get, Uri);
foreach (var header in RequestHeaders) foreach (var header in RequestHeaders)
request.Headers.Add(header.Key, header.Value); request.Headers.Add(header.Key, header.Value);
var response = new HttpClient().Send(request, HttpCompletionOption.ResponseHeadersRead); var response = new HttpClient().Send(request, HttpCompletionOption.ResponseHeadersRead, _cancellationSource.Token);
if (response.StatusCode != HttpStatusCode.PartialContent) if (response.StatusCode != HttpStatusCode.PartialContent)
throw new WebException($"Server at {Uri.Host} responded with unexpected status code: {response.StatusCode}."); throw new WebException($"Server at {Uri.Host} responded with unexpected status code: {response.StatusCode}.");
@ -153,15 +152,15 @@ namespace AaxDecrypter
if (WritePosition == 0) if (WritePosition == 0)
ContentLength = response.Content.Headers.ContentLength.GetValueOrDefault(); ContentLength = response.Content.Headers.ContentLength.GetValueOrDefault();
var networkStream = response.Content.ReadAsStream(); var networkStream = response.Content.ReadAsStream(_cancellationSource.Token);
_downloadedPiece = new EventWaitHandle(false, EventResetMode.AutoReset); _downloadedPiece = new EventWaitHandle(false, EventResetMode.AutoReset);
//Download the file in the background. //Download the file in the background.
return Task.Run(() => DownloadFile(networkStream)); return DownloadFile(networkStream);
} }
/// <summary> Download <see cref="Uri"/> to <see cref="SaveFilePath"/>.</summary> /// <summary> Download <see cref="Uri"/> to <see cref="SaveFilePath"/>.</summary>
private void DownloadFile(Stream networkStream) private async Task DownloadFile(Stream networkStream)
{ {
var downloadPosition = WritePosition; var downloadPosition = WritePosition;
var nextFlush = downloadPosition + DATA_FLUSH_SZ; var nextFlush = downloadPosition + DATA_FLUSH_SZ;
@ -172,14 +171,14 @@ namespace AaxDecrypter
int bytesRead; int bytesRead;
do do
{ {
bytesRead = networkStream.Read(buff, 0, DOWNLOAD_BUFF_SZ); bytesRead = await networkStream.ReadAsync(buff, 0, DOWNLOAD_BUFF_SZ, _cancellationSource.Token);
_writeFile.Write(buff, 0, bytesRead); await _writeFile.WriteAsync(buff, 0, bytesRead, _cancellationSource.Token);
downloadPosition += bytesRead; downloadPosition += bytesRead;
if (downloadPosition > nextFlush) if (downloadPosition > nextFlush)
{ {
_writeFile.Flush(); await _writeFile.FlushAsync(_cancellationSource.Token);
WritePosition = downloadPosition; WritePosition = downloadPosition;
Update(); Update();
nextFlush = downloadPosition + DATA_FLUSH_SZ; nextFlush = downloadPosition + DATA_FLUSH_SZ;
@ -261,7 +260,7 @@ namespace AaxDecrypter
var toRead = Math.Min(count, Length - Position); var toRead = Math.Min(count, Length - Position);
WaitToPosition(Position + toRead); WaitToPosition(Position + toRead);
return _readFile.Read(buffer, offset, count); return IsCancelled ? 0: _readFile.Read(buffer, offset, count);
} }
public override long Seek(long offset, SeekOrigin origin) public override long Seek(long offset, SeekOrigin origin)
@ -274,7 +273,7 @@ namespace AaxDecrypter
}; };
WaitToPosition(newPosition); WaitToPosition(newPosition);
return _readFile.Position = newPosition; return IsCancelled ? 0 : (_readFile.Position = newPosition);
} }
/// <summary>Blocks until the file has downloaded to at least <paramref name="requiredPosition"/>, then returns. </summary> /// <summary>Blocks until the file has downloaded to at least <paramref name="requiredPosition"/>, then returns. </summary>
@ -285,13 +284,13 @@ namespace AaxDecrypter
&& _backgroundDownloadTask?.IsCompleted is false && _backgroundDownloadTask?.IsCompleted is false
&& !IsCancelled) && !IsCancelled)
{ {
_downloadedPiece.WaitOne(100); _downloadedPiece.WaitOne(50);
} }
} }
public override void Close() public override void Close()
{ {
IsCancelled = true; _cancellationSource.Cancel();
_backgroundDownloadTask?.Wait(); _backgroundDownloadTask?.Wait();
_readFile.Close(); _readFile.Close();