use std::{
io::{self, BufRead, Read, Write},
thread,
};
use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
use crate::{Connection, ConnectionRx, ConnectionTx, GetMessageKind, Message};
// Command-line options selecting the input source of the stdio transport.
//
// Both fields use an empty string to mean "not set". They are documented with
// plain `//` comments on purpose: with the `clap` feature enabled, the derive
// turns `///` doc comments into CLI help text, which would change `--help`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
pub struct MirrorArgs {
// File that receives a copy of the bytes the transport reads from stdin
// (created, or truncated if it already exists). Ignored when `replay` is set.
#[cfg_attr(feature = "clap", clap(long, default_value = "", value_name = "FILE"))]
pub mirror: String,
// File whose contents are read as input instead of stdin; takes precedence
// over `mirror`.
#[cfg_attr(feature = "clap", clap(long, default_value = "", value_name = "FILE"))]
pub replay: String,
}
pub fn with_stdio_transport<M: TryFrom<Message, Error = anyhow::Error> + GetMessageKind>(
args: MirrorArgs,
f: impl FnOnce(Connection<M>) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
with_stdio_transport_impl(args, M::get_message_kind(), |conn| f(conn.into()))
}
/// Runs `f` over an untyped [`Connection`] whose messages travel over stdin/stdout.
///
/// The input is chosen from `args`, where an empty string means "unset":
/// - `replay` set: messages are read from that file instead of stdin;
/// - otherwise, `mirror` set: stdin is read, and every byte consumed from it is also
///   written to that file (created, or truncated if it exists);
/// - otherwise: stdin alone.
///
/// Output always goes to stdout. After `f` returns `Ok`, this waits for the writer
/// thread to finish; the reader thread is not joined.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the replay file cannot be opened, or the mirror file cannot be created (checked
///   before any thread is spawned);
/// - `f` fails;
/// - the writer thread reports an I/O error.
fn with_stdio_transport_impl(
    args: MirrorArgs,
    kind: crate::MessageKind,
    f: impl FnOnce(Connection<Message>) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
    // Input source handed to the reader thread.
    enum Input {
        Replay(std::fs::File),
        Mirror(std::fs::File),
        Stdin,
    }

    // Touch the filesystem on the caller's thread. An invalid path is then reported as
    // an error from this function instead of as a panic inside the reader thread, which
    // is never joined here (only `join_write` is called below).
    let input = if !args.replay.is_empty() {
        let file = std::fs::File::open(&args.replay).map_err(|err| {
            anyhow::anyhow!("failed to open replay file {:?}: {err}", args.replay)
        })?;
        Input::Replay(file)
    } else if !args.mirror.is_empty() {
        let file = std::fs::File::create(&args.mirror).map_err(|err| {
            anyhow::anyhow!("failed to create mirror file {:?}: {err}", args.mirror)
        })?;
        Input::Mirror(file)
    } else {
        Input::Stdin
    };

    // stdin is locked inside the reader thread, which is where it is actually read.
    let i = move || -> Box<dyn BufRead> {
        match input {
            Input::Replay(file) => Box::new(std::io::BufReader::new(file)),
            Input::Mirror(file) => Box::new(MirrorWriter(
                std::io::stdin().lock(),
                file,
                std::sync::Once::new(),
            )),
            Input::Stdin => Box::new(std::io::stdin().lock()),
        }
    };
    let o = || std::io::stdout().lock();

    let (event_sender, event_receiver) = unbounded::<crate::Event>();
    let (lsp_sender, lsp_receiver, io_threads) = io_transport(kind, i, o);
    let connection = Connection {
        sender: ConnectionTx {
            event: event_sender,
            lsp: lsp_sender,
            marker: std::marker::PhantomData,
        },
        receiver: ConnectionRx {
            event: event_receiver,
            lsp: lsp_receiver,
            marker: std::marker::PhantomData,
        },
    };
    f(connection)?;
    io_threads.join_write()?;
    Ok(())
}
pub fn io_transport<I: BufRead, O: Write>(
kind: crate::MessageKind,
inp: impl FnOnce() -> I + Send + Sync + 'static,
out: impl FnOnce() -> O + Send + Sync + 'static,
) -> (Sender<Message>, Receiver<Message>, IoThreads) {
let (writer_sender, writer_receiver) = bounded::<Message>(0);
let writer = thread::spawn(move || {
let mut out = out();
let res = writer_receiver
.into_iter()
.try_for_each(|it| it.write(&mut out));
log::info!("writer thread finished");
res
});
let (reader_sender, reader_receiver) = bounded::<Message>(0);
let reader = thread::spawn(move || {
let mut inp = inp();
let read_impl = match kind {
#[cfg(feature = "lsp")]
crate::MessageKind::Lsp => Message::read_lsp::<I>,
#[cfg(feature = "dap")]
crate::MessageKind::Dap => Message::read_dap::<I>,
};
while let Some(msg) = read_impl(&mut inp)? {
#[cfg(feature = "lsp")]
use crate::LspMessage;
#[cfg(feature = "lsp")]
let is_exit = matches!(&msg, Message::Lsp(LspMessage::Notification(n)) if n.is_exit());
log::trace!("sending message {msg:#?}");
reader_sender
.send(msg)
.expect("receiver was dropped, failed to send a message");
#[cfg(feature = "lsp")]
if is_exit {
break;
}
}
log::info!("reader thread finished");
Ok(())
});
let threads = IoThreads { reader, writer };
(writer_sender, reader_receiver, threads)
}
/// Join handles for the reader and writer threads spawned by [`io_transport`].
#[derive(Debug)]
pub struct IoThreads {
    reader: thread::JoinHandle<io::Result<()>>,
    writer: thread::JoinHandle<io::Result<()>>,
}

impl IoThreads {
    /// Waits for the reader thread, then for the writer thread.
    ///
    /// # Errors
    ///
    /// If the reader failed, returns its I/O error; the writer is then not joined.
    /// Otherwise, returns the writer's result.
    ///
    /// # Panics
    ///
    /// If a joined thread panicked, prints a note to stderr and resumes unwinding with
    /// that thread's original panic payload.
    pub fn join(self) -> io::Result<()> {
        Self::join_thread(self.reader, "reader")?;
        Self::join_thread(self.writer, "writer")
    }

    /// Waits for the writer thread only; the reader thread is left running.
    ///
    /// # Errors
    ///
    /// Returns the writer's I/O error, if any.
    ///
    /// # Panics
    ///
    /// Resumes unwinding with the writer's original payload if it panicked.
    pub fn join_write(self) -> io::Result<()> {
        Self::join_thread(self.writer, "writer")
    }

    /// Joins `handle`. If that thread panicked, its panic is re-raised on the calling
    /// thread.
    fn join_thread(handle: thread::JoinHandle<io::Result<()>>, name: &str) -> io::Result<()> {
        match handle.join() {
            Ok(res) => res,
            Err(payload) => {
                eprintln!("{name} panicked!");
                // `resume_unwind` re-raises the payload unchanged. `panic_any(payload)`
                // would box it a second time, so the original message could no longer
                // be downcast by whoever catches it.
                std::panic::resume_unwind(payload)
            }
        }
    }
}
struct MirrorWriter<R: Read, W: Write>(R, W, std::sync::Once);
impl<R: Read, W: Write> Read for MirrorWriter<R, W> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let res = self.0.read(buf)?;
if let Err(err) = self.1.write_all(&buf[..res]) {
self.2.call_once(|| {
log::warn!("failed to write to mirror: {err}");
});
}
Ok(res)
}
}
impl<R: Read + BufRead, W: Write> BufRead for MirrorWriter<R, W> {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.0.fill_buf()
}
fn consume(&mut self, amt: usize) {
let buf = self.0.fill_buf().unwrap();
if let Err(err) = self.1.write_all(&buf[..amt]) {
self.2.call_once(|| {
log::warn!("failed to write to mirror: {err}");
});
}
self.0.consume(amt);
}
}