1use std::{
4 io::{self, BufRead, Read, Write},
5 thread,
6};
7
8use crossbeam_channel::{Receiver, Sender, bounded, unbounded};
9
10use crate::{Connection, ConnectionRx, ConnectionTx, GetMessageKind, Message};
11
/// Command-line options selecting where the stdio transport reads its input from.
///
/// Both fields are file paths, and an empty string means "not set". When `replay` is set,
/// it takes precedence and `mirror` is ignored.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
pub struct MirrorArgs {
    /// Copy every byte read from stdin into this file (created or truncated).
    /// Leave empty to disable mirroring.
    #[cfg_attr(feature = "clap", clap(long, default_value = "", value_name = "FILE"))]
    pub mirror: String,
    /// Read input from this file instead of stdin (e.g. a file recorded with `--mirror`).
    /// Leave empty to read from stdin.
    #[cfg_attr(feature = "clap", clap(long, default_value = "", value_name = "FILE"))]
    pub replay: String,
}
43
44pub fn with_stdio_transport<M: TryFrom<Message, Error = anyhow::Error> + GetMessageKind>(
46 args: MirrorArgs,
47 f: impl FnOnce(Connection<M>) -> anyhow::Result<()>,
48) -> anyhow::Result<()> {
49 with_stdio_transport_impl(args, M::MESSAGE_KIND, |conn| f(conn.into()))
50}
51
52fn with_stdio_transport_impl(
54 args: MirrorArgs,
55 kind: crate::MessageKind,
56 f: impl FnOnce(Connection<Message>) -> anyhow::Result<()>,
57) -> anyhow::Result<()> {
58 let replay = args.replay.clone();
60 let mirror = args.mirror.clone();
61 let i = move || -> Box<dyn BufRead> {
62 if !replay.is_empty() {
63 let file = std::fs::File::open(&replay).unwrap();
65 let file = std::io::BufReader::new(file);
66 Box::new(file)
67 } else if mirror.is_empty() {
68 let stdin = std::io::stdin().lock();
70 Box::new(stdin)
71 } else {
72 let file = std::fs::File::create(&mirror).unwrap();
73 let stdin = std::io::stdin().lock();
74 Box::new(MirrorWriter(stdin, file, std::sync::Once::new()))
75 }
76 };
77 let o = || std::io::stdout().lock();
78
79 let (event_sender, event_receiver) = unbounded::<crate::Event>();
80
81 let (lsp_sender, lsp_receiver, io_threads) = io_transport(kind, i, o);
84 let connection = Connection {
85 sender: ConnectionTx {
88 event: event_sender,
89 lsp: lsp_sender,
90 marker: std::marker::PhantomData,
91 },
92 receiver: ConnectionRx {
93 event: event_receiver,
94 lsp: lsp_receiver,
95 marker: std::marker::PhantomData,
96 },
97 };
98
99 f(connection)?;
100
101 io_threads.join_write()?;
102
103 Ok(())
104}
105
106pub fn io_transport<I: BufRead, O: Write>(
119 kind: crate::MessageKind,
120 inp: impl FnOnce() -> I + Send + Sync + 'static,
121 out: impl FnOnce() -> O + Send + Sync + 'static,
122) -> (Sender<Message>, Receiver<Message>, IoThreads) {
123 let (writer_sender, writer_receiver) = bounded::<Message>(0);
124 let writer = thread::spawn(move || {
125 let mut out = out();
126 let res = writer_receiver
127 .into_iter()
128 .try_for_each(|it| it.write(&mut out));
129
130 log::info!("writer thread finished");
131 res
132 });
133 let (reader_sender, reader_receiver) = bounded::<Message>(0);
134 let reader = thread::spawn(move || {
135 let mut inp = inp();
136 let read_impl = match kind {
137 #[cfg(feature = "lsp")]
138 crate::MessageKind::Lsp => Message::read_lsp::<I>,
139 #[cfg(feature = "dap")]
140 crate::MessageKind::Dap => Message::read_dap::<I>,
141 };
142 while let Some(msg) = read_impl(&mut inp)? {
143 #[cfg(feature = "lsp")]
144 use crate::LspMessage;
145 #[cfg(feature = "lsp")]
146 let is_exit = matches!(&msg, Message::Lsp(LspMessage::Notification(n)) if n.is_exit());
147
148 log::trace!("sending message {msg:#?}");
149 reader_sender
150 .send(msg)
151 .expect("receiver was dropped, failed to send a message");
152
153 #[cfg(feature = "lsp")]
154 if is_exit {
155 break;
156 }
157 }
158
159 log::info!("reader thread finished");
160 Ok(())
161 });
162 let threads = IoThreads { reader, writer };
163 (writer_sender, reader_receiver, threads)
164}
165
/// Join handles for the reader and writer threads spawned by `io_transport`.
///
/// Dropping this value without joining detaches both threads.
pub struct IoThreads {
    reader: thread::JoinHandle<io::Result<()>>,
    writer: thread::JoinHandle<io::Result<()>>,
}

impl IoThreads {
    /// Waits for the reader thread, then the writer thread, to finish.
    ///
    /// # Errors
    ///
    /// If the reader returned an error, that error is returned immediately. In that case
    /// the writer is not joined and is detached. Otherwise the writer's result is
    /// returned.
    ///
    /// # Panics
    ///
    /// If either thread panicked, prints a note to stderr and re-raises that thread's
    /// panic on the calling thread, preserving its original payload.
    pub fn join(self) -> io::Result<()> {
        Self::join_thread(self.reader, "reader")?;
        Self::join_thread(self.writer, "writer")
    }

    /// Waits only for the writer thread and returns its result.
    ///
    /// The reader handle is dropped, which detaches the reader thread.
    ///
    /// # Panics
    ///
    /// Re-raises the writer thread's panic, with its original payload, if it panicked.
    pub fn join_write(self) -> io::Result<()> {
        Self::join_thread(self.writer, "writer")
    }

    /// Joins one transport thread, labelling it `role` in the stderr note on panic.
    fn join_thread(handle: thread::JoinHandle<io::Result<()>>, role: &str) -> io::Result<()> {
        match handle.join() {
            Ok(result) => result,
            Err(payload) => {
                eprintln!("{role} panicked!");
                // `resume_unwind` propagates the child's payload unchanged (so it can still
                // be downcast to `&str`/`String`). It does not re-run the panic hook.
                // `panic_any(payload)` would instead wrap the box in another box.
                std::panic::resume_unwind(payload)
            }
        }
    }
}
202
203struct MirrorWriter<R: Read, W: Write>(R, W, std::sync::Once);
204
205impl<R: Read, W: Write> Read for MirrorWriter<R, W> {
206 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
207 let res = self.0.read(buf)?;
208
209 if let Err(err) = self.1.write_all(&buf[..res]) {
210 self.2.call_once(|| {
211 log::warn!("failed to write to mirror: {err}");
212 });
213 }
214
215 Ok(res)
216 }
217}
218
219impl<R: Read + BufRead, W: Write> BufRead for MirrorWriter<R, W> {
220 fn fill_buf(&mut self) -> io::Result<&[u8]> {
221 self.0.fill_buf()
222 }
223
224 fn consume(&mut self, amt: usize) {
225 let buf = self.0.fill_buf().unwrap();
226
227 if let Err(err) = self.1.write_all(&buf[..amt]) {
228 self.2.call_once(|| {
229 log::warn!("failed to write to mirror: {err}");
230 });
231 }
232
233 self.0.consume(amt);
234 }
235}