22//!
33//! Provides an implementation of the [`blitz_traits::net::NetProvider`] trait.
44
5- use blitz_traits:: net:: { Body , Bytes , NetHandler , NetProvider , NetWaker , Request } ;
5+ // use blitz_traits::net::{Body, Bytes, NetHandler, NetProvider, NetWaker, Request};
6+ use blitz_traits:: net:: { AbortSignal , Body , Bytes , NetHandler , NetProvider , NetWaker , Request } ;
67use data_url:: DataUrl ;
7- use std:: sync:: Arc ;
8+ use std:: { marker :: PhantomData , pin :: Pin , sync:: Arc , task :: Poll } ;
89use tokio:: runtime:: Handle ;
910
1011#[ cfg( feature = "cache" ) ]
@@ -102,16 +103,6 @@ impl Provider {
102103 } )
103104 }
104105
105- async fn fetch_with_handler (
106- client : Client ,
107- request : Request ,
108- handler : Box < dyn NetHandler > ,
109- ) -> Result < ( ) , ProviderError > {
110- let ( response_url, bytes) = Self :: fetch_inner ( client, request) . await ?;
111- handler. bytes ( response_url, bytes) ;
112- Ok ( ( ) )
113- }
114-
115106 #[ allow( clippy:: type_complexity) ]
116107 pub fn fetch_with_callback (
117108 & self ,
@@ -155,7 +146,7 @@ impl Provider {
155146}
156147
157148impl NetProvider for Provider {
158- fn fetch ( & self , doc_id : usize , request : Request , handler : Box < dyn NetHandler > ) {
149+ fn fetch ( & self , doc_id : usize , mut request : Request , handler : Box < dyn NetHandler > ) {
159150 let client = self . client . clone ( ) ;
160151
161152 #[ cfg( feature = "debug_log" ) ]
@@ -166,23 +157,80 @@ impl NetProvider for Provider {
166157 #[ cfg( feature = "debug_log" ) ]
167158 let url = request. url . to_string ( ) ;
168159
169- let _res = Self :: fetch_with_handler ( client, request, handler) . await ;
170-
171- #[ cfg( feature = "debug_log" ) ]
172- if let Err ( e) = _res {
173- eprintln ! ( "Error fetching {url}: {e:?}" ) ;
160+ let signal = request. signal . take ( ) ;
161+ let result = if let Some ( signal) = signal {
162+ AbortFetch :: new (
163+ signal,
164+ Box :: pin ( async move { Self :: fetch_inner ( client, request) . await } ) ,
165+ )
166+ . await
174167 } else {
175- println ! ( "Success {url}" ) ;
176- }
168+ Self :: fetch_inner ( client , request ) . await
169+ } ;
177170
178171 // Call the waker to notify of completed network request
179- waker. wake ( doc_id)
172+ waker. wake ( doc_id) ;
173+
174+ match result {
175+ Ok ( ( response_url, bytes) ) => {
176+ handler. bytes ( response_url, bytes) ;
177+ #[ cfg( feature = "debug_log" ) ]
178+ println ! ( "Success {url}" ) ;
179+ }
180+ Err ( e) => {
181+ #[ cfg( feature = "debug_log" ) ]
182+ eprintln ! ( "Error fetching {url}: {e:?}" ) ;
183+ #[ cfg( not( feature = "debug_log" ) ) ]
184+ let _ = e;
185+ }
186+ } ;
180187 } ) ;
181188 }
182189}
183190
191+ /// A future that is cancellable using an AbortSignal
192+ struct AbortFetch < F , T > {
193+ signal : AbortSignal ,
194+ future : F ,
195+ _rt : PhantomData < T > ,
196+ }
197+
198+ impl < F , T > AbortFetch < F , T > {
199+ fn new ( signal : AbortSignal , future : F ) -> Self {
200+ Self {
201+ signal,
202+ future,
203+ _rt : PhantomData ,
204+ }
205+ }
206+ }
207+
208+ impl < F , T > Future for AbortFetch < F , T >
209+ where
210+ F : Future + Unpin + Send + ' static ,
211+ F :: Output : Send + Into < Result < T , ProviderError > > + ' static ,
212+ T : Unpin ,
213+ {
214+ type Output = Result < T , ProviderError > ;
215+
216+ fn poll (
217+ mut self : std:: pin:: Pin < & mut Self > ,
218+ cx : & mut std:: task:: Context < ' _ > ,
219+ ) -> std:: task:: Poll < Self :: Output > {
220+ if self . signal . aborted ( ) {
221+ return Poll :: Ready ( Err ( ProviderError :: Abort ) ) ;
222+ }
223+
224+ match Pin :: new ( & mut self . future ) . poll ( cx) {
225+ Poll :: Ready ( output) => Poll :: Ready ( output. into ( ) ) ,
226+ Poll :: Pending => Poll :: Pending ,
227+ }
228+ }
229+ }
230+
184231#[ derive( Debug ) ]
185232pub enum ProviderError {
233+ Abort ,
186234 Io ( std:: io:: Error ) ,
187235 DataUrl ( data_url:: DataUrlError ) ,
188236 DataUrlBase64 ( data_url:: forgiving_base64:: InvalidBase64 ) ,
0 commit comments