diff --git a/packages/binding/src/lib.rs b/packages/binding/src/lib.rs index db189ff..b2d5ee1 100644 --- a/packages/binding/src/lib.rs +++ b/packages/binding/src/lib.rs @@ -45,6 +45,7 @@ struct JsDownloadOptions { pub entry_listener: Option, pub retry_time: Option, pub toc_path: Option, + pub registries: Option>, } #[napi(object)] @@ -118,6 +119,12 @@ fn parse_download_options( None }; + let registries = if let Some(registries) = options.registries.take() { + Some(registries) + } else { + None + }; + Ok(DownloadOptions { download_dir: options.download_dir, bucket_count: options.bucket_count as u8, @@ -126,6 +133,7 @@ fn parse_download_options( entry_listener, retry_time: retry_time as u8, toc_path, + registries, }) } @@ -276,7 +284,11 @@ impl JsDownloader { if self.inner.is_some() { return Ok(()); } - let http_pool = HTTPPool::new(self.options.http_concurrent_count).map_err(|e| { + let http_pool = HTTPPool::new( + self.options.http_concurrent_count, + self.options.registries.clone(), + ) + .map_err(|e| { Error::new( Status::FunctionExpected, format!("create reqwester failed: {:?}", e), diff --git a/packages/cli/lib/download_dependency.js b/packages/cli/lib/download_dependency.js index 4e790d8..562d692 100644 --- a/packages/cli/lib/download_dependency.js +++ b/packages/cli/lib/download_dependency.js @@ -46,6 +46,7 @@ async function download(options) { const downloader = new Downloader({ entryListener, productionMode: options.productionMode, + registries: options.registries, }); await downloader.init(); diff --git a/packages/cli/lib/downloader.js b/packages/cli/lib/downloader.js index 7827af7..61078d5 100644 --- a/packages/cli/lib/downloader.js +++ b/packages/cli/lib/downloader.js @@ -18,6 +18,7 @@ class Downloader { * @param {NodeJS.Architecture} [options.arch] - * @param {boolean} [options.productionMode] - * @param {{function(*)}} [options.entryListener] - + * @param {Array} [options.registries] - */ constructor(options) { this.entryListener = options.entryListener; @@ -30,6 +31,7 @@ class Downloader { this.rapidDownloader = this.createRapidDownloader(); this.taskMap = new Map(); this._dumpData = null; + this.registries = options.registries; } async init() { @@ -58,6 +60,7 @@ class Downloader { map: npmCacheConfigPath, index: npmIndexConfigPath, }, + registries: this.registries, }); } diff --git a/packages/downloader/src/download.rs b/packages/downloader/src/download.rs index 2b66611..af19b07 100644 --- a/packages/downloader/src/download.rs +++ b/packages/downloader/src/download.rs @@ -247,7 +247,7 @@ mod test { bucket_path: &str, entry_listener: Option, ) -> Downloader { - let http_pool = HTTPPool::new(1).expect("create http pool failed"); + let http_pool = HTTPPool::new(1, None).expect("create http pool failed"); let store = NpmStore::new( 1, Path::new(bucket_path), diff --git a/packages/downloader/src/http/pool.rs b/packages/downloader/src/http/pool.rs index 03326b1..cc88093 100644 --- a/packages/downloader/src/http/pool.rs +++ b/packages/downloader/src/http/pool.rs @@ -47,14 +47,15 @@ impl Executor for HTTPReqwester { } impl HTTPPool { - pub fn new(max_concurrent: u8) -> ProjectResult { + pub fn new(max_concurrent: u8, registries: Option>) -> ProjectResult { let client_builder = reqwest::ClientBuilder::new() .tcp_keepalive(Duration::from_secs(60)) .connection_verbose(true) .redirect(reqwest::redirect::Policy::limited(10)) .http1_only() .use_rustls_tls(); - let client_builder = HTTPReqwester::prepare_dns_resolve(client_builder)?; + let domains = registries.unwrap_or(vec!["registry.npmmirror.com".to_owned()]); + let client_builder = HTTPReqwester::prepare_dns_resolve(client_builder, domains)?; let client = client_builder.build()?; let client = Arc::new(client); let mut reqwesters = Vec::with_capacity(max_concurrent as usize); @@ -101,7 +102,7 @@ mod test { #[tokio::test] async fn test_pool() { - let pool = HTTPPool::new(2).unwrap(); + let pool = HTTPPool::new(2, None).unwrap(); let (sx, mut rx) = mpsc::channel::(2); let download_handler = tokio::spawn(async move { while let Some(mut response) = rx.recv().await { diff --git a/packages/downloader/src/http/reqwester.rs b/packages/downloader/src/http/reqwester.rs index c10a41b..292e30a 100644 --- a/packages/downloader/src/http/reqwester.rs +++ b/packages/downloader/src/http/reqwester.rs @@ -22,7 +22,7 @@ pub struct HTTPReqwester { } impl HTTPReqwester { - pub fn new() -> TnpmResult { + pub fn new(registries: Option>) -> TnpmResult { let client_builder = reqwest::ClientBuilder::new() .tcp_keepalive(Duration::from_secs(60)) .connection_verbose(true) @@ -30,7 +30,8 @@ impl HTTPReqwester { .http1_only() .use_rustls_tls(); - let client_builder = HTTPReqwester::prepare_dns_resolve(client_builder)?; + let domains = registries.unwrap_or(vec!["registry.npmmirror.com".to_owned()]); + let client_builder = HTTPReqwester::prepare_dns_resolve(client_builder, domains)?; let client = client_builder.build()?; Ok(HTTPReqwester { client: Arc::new(client), @@ -43,9 +44,9 @@ impl HTTPReqwester { pub(crate) fn prepare_dns_resolve( mut client_builder: reqwest::ClientBuilder, + pre_resolve_list: Vec, ) -> Result { // 在大量连接建立时,容易发生 DNS 超时 - let pre_resolve_list = vec!["registry.npmmirror.com"]; let mut client_builder = client_builder; for address in pre_resolve_list { let address_with_port = format!("{}:443", address); @@ -56,7 +57,7 @@ impl HTTPReqwester { format!("not found address for {}", address), ) })?; - client_builder = client_builder.resolve(address, socket_addr); + client_builder = client_builder.resolve(&address, socket_addr); } Ok(client_builder) } @@ -100,7 +101,7 @@ mod test { #[tokio::test] async fn test_download() { - let req = HTTPReqwester::new().unwrap(); + let req = HTTPReqwester::new(None).unwrap(); let stream = req .request( PackageRequestBuilder::new() diff --git a/packages/downloader/src/lib.rs b/packages/downloader/src/lib.rs index ebe47af..05b2d1d 100644 --- a/packages/downloader/src/lib.rs +++ b/packages/downloader/src/lib.rs @@ -47,6 +47,7 @@ pub struct DownloadOptions { pub entry_listener: Option, pub download_timeout: Duration, pub toc_path: Option, + pub registries: Option>, } pub async fn download( @@ -70,7 +71,7 @@ pub async fn download( entry_listener, ) .await?; - let http_pool = HTTPPool::new(http_concurrent_count)?; + let http_pool = HTTPPool::new(http_concurrent_count, opts.registries)?; let toc_index_store = Arc::new(TocIndexStore::new()); let mut downloader = Downloader::new(store, http_pool, toc_index_store.clone(), retry_time); downloader.batch_download(pkg_requests).await?; @@ -110,6 +111,7 @@ mod test_downloader { entry_listener: None, retry_time: 1, toc_path: None, + registries: None, }, ) .await diff --git a/packages/downloader/src/main.rs b/packages/downloader/src/main.rs index 01489dc..f3058d7 100644 --- a/packages/downloader/src/main.rs +++ b/packages/downloader/src/main.rs @@ -22,7 +22,7 @@ use tokio::task::JoinHandle; async fn black_hole_download() { let lock = get_chair_lock(); let pkg_requests = lock.list_all_packages(); - let pool = HTTPPool::new(50).unwrap(); + let pool = HTTPPool::new(50, None).unwrap(); let bucket = 50; let (sx, rx_list) = subscriber(bucket); let handlers: Vec> = rx_list @@ -61,7 +61,7 @@ async fn file_download() { } }); - let http_pool = HTTPPool::new(bucket_size * 2).unwrap(); + let http_pool = HTTPPool::new(bucket_size * 2, None).unwrap(); let store = NpmStore::new( bucket_size, download_dir.as_path(),