sync_ls/
transport.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
//! Transport layer for LSP messages.

use std::{
    io::{self, BufRead, Read, Write},
    thread,
};

use crossbeam_channel::{bounded, unbounded, Receiver, Sender};

use crate::{Connection, ConnectionRx, ConnectionTx, GetMessageKind, Message};

/// Convenience cli arguments for setting up a transport with an optional mirror
/// or replay file.
///
/// The `mirror` argument will write the stdin to the file.
/// The `replay` argument will read the file as input.
///
/// # Example
///
/// The example below shows the typical usage of the `MirrorArgs` struct.
/// It records an LSP or DAP session and replays it to compare the output.
///
/// If the language server has stable output, the replayed output should be the
/// same.
///
/// ```bash
/// $ my-lsp --mirror /tmp/mirror.log > responses.txt
/// $ ls /tmp
/// mirror.log
/// $ my-lsp --replay /tmp/mirror.log > responses-replayed.txt
/// $ diff responses.txt responses-replayed.txt
/// ```
// NOTE(review): with the `clap` feature enabled, the `///` doc comments on this
// struct and its fields are presumably picked up by the clap derive as help
// text, so they are kept verbatim; extra maintainer notes live in `//` comments.
//
// Both fields use the empty string (the `Default` / clap default) to mean
// "not set". When both are set, `replay` takes precedence and `mirror` is
// ignored (see `with_stdio_transport_impl`).
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
pub struct MirrorArgs {
    /// Mirror the stdin to the file
    // Empty disables mirroring; otherwise the file is created with
    // `File::create` (truncating any existing file) and receives a copy of stdin.
    #[cfg_attr(feature = "clap", clap(long, default_value = "", value_name = "FILE"))]
    pub mirror: String,
    /// Replay input from the file
    // Empty means "read from stdin"; otherwise the file replaces stdin as input.
    #[cfg_attr(feature = "clap", clap(long, default_value = "", value_name = "FILE"))]
    pub replay: String,
}

/// Note that we must have our logging only write out to stderr.
pub fn with_stdio_transport<M: TryFrom<Message, Error = anyhow::Error> + GetMessageKind>(
    args: MirrorArgs,
    f: impl FnOnce(Connection<M>) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
    with_stdio_transport_impl(args, M::get_message_kind(), |conn| f(conn.into()))
}

/// Builds a stdio-backed [`Connection`] for `kind` messages and runs `f` on it.
///
/// Input is read from the file named by `args.replay` when it is non-empty.
/// Otherwise it is read from stdin and, when `args.mirror` is non-empty, also
/// copied into that file. Output always goes to stdout, so logging must only
/// write to stderr.
///
/// After `f` returns, only the writer thread is joined; the reader thread is
/// not waited for.
///
/// # Errors
///
/// Returns an error if the replay file cannot be opened, the mirror file cannot
/// be created, `f` fails, or the writer thread reports an I/O error.
fn with_stdio_transport_impl(
    args: MirrorArgs,
    kind: crate::MessageKind,
    f: impl FnOnce(Connection<Message>) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
    // Open the files on this thread so a bad path is reported to the caller.
    // Opening them lazily inside the reader-thread closure would turn such a
    // failure into a panic on a thread that is never joined.
    let replay_file = if args.replay.is_empty() {
        None
    } else {
        let path = &args.replay;
        let file = std::fs::File::open(path)
            .map_err(|err| anyhow::anyhow!("failed to open replay file {path:?}: {err}"))?;
        Some(file)
    };
    // Replay takes precedence: the mirror file is only created when reading stdin.
    let mirror_file = if replay_file.is_some() || args.mirror.is_empty() {
        None
    } else {
        let path = &args.mirror;
        let file = std::fs::File::create(path)
            .map_err(|err| anyhow::anyhow!("failed to create mirror file {path:?}: {err}"))?;
        Some(file)
    };

    // Set up input and output. The stdin/stdout locks are taken inside these
    // closures, which `io_transport` runs on the reader/writer threads.
    let i = move || -> Box<dyn BufRead> {
        match (replay_file, mirror_file) {
            // Get input from file
            (Some(file), _) => Box::new(std::io::BufReader::new(file)),
            // Get input from stdin, copying everything read into the mirror file
            (None, Some(file)) => Box::new(MirrorWriter(
                std::io::stdin().lock(),
                file,
                std::sync::Once::new(),
            )),
            // Get input from stdin
            (None, None) => Box::new(std::io::stdin().lock()),
        }
    };
    let o = || std::io::stdout().lock();

    let (event_sender, event_receiver) = unbounded::<crate::Event>();

    // Create the transport. Includes the stdio (stdin and stdout) versions but this
    // could also be implemented to use sockets or HTTP.
    let (lsp_sender, lsp_receiver, io_threads) = io_transport(kind, i, o);
    let connection = Connection {
        sender: ConnectionTx {
            event: event_sender,
            lsp: lsp_sender,
            marker: std::marker::PhantomData,
        },
        receiver: ConnectionRx {
            event: event_receiver,
            lsp: lsp_receiver,
            marker: std::marker::PhantomData,
        },
    };

    f(connection)?;

    io_threads.join_write()?;

    Ok(())
}

/// Creates an LSP (or DAP) connection over arbitrary I/O streams.
///
/// `inp` is called on the spawned reader thread and `out` on the spawned writer
/// thread, so values that should be created on those threads (such as stdio
/// locks) can be built there.
///
/// Returns the sender for outgoing messages, the receiver for incoming
/// messages, and the [`IoThreads`] handles. Both channels are zero-capacity
/// (rendezvous) channels.
///
/// - The writer thread runs until every clone of the returned sender is dropped
///   or a write fails.
/// - The reader thread runs until reading yields `None` (presumably end of
///   input), a read fails, or, with the `lsp` feature, right after it has
///   forwarded an `exit` notification.
///
/// # Panics
///
/// The reader thread panics if the returned receiver has been dropped when it
/// tries to deliver a message.
///
/// # Example
///
/// ```
/// use std::io::{stdin, stdout};
/// use sync_ls::{Message, MessageKind, transport::{io_transport, IoThreads}};
/// use crossbeam_channel::{Receiver, Sender};
/// pub fn stdio_transport() -> (Sender<Message>, Receiver<Message>, IoThreads) {
///   io_transport(MessageKind::Lsp, || stdin().lock(), || stdout().lock())
/// }
/// ```
pub fn io_transport<I: BufRead, O: Write>(
    kind: crate::MessageKind,
    inp: impl FnOnce() -> I + Send + Sync + 'static,
    out: impl FnOnce() -> O + Send + Sync + 'static,
) -> (Sender<Message>, Receiver<Message>, IoThreads) {
    // Outgoing path: callers hand messages to the writer thread, which
    // serializes them onto `out`.
    let (outgoing_tx, outgoing_rx) = bounded::<Message>(0);
    let writer = thread::spawn(move || {
        let mut sink = out();
        // Stops at the first write error, or once every sender is dropped.
        let outcome = outgoing_rx
            .into_iter()
            .try_for_each(|msg| msg.write(&mut sink));

        log::info!("writer thread finished");
        outcome
    });

    // Incoming path: the reader thread parses messages from `inp` and forwards them.
    let (incoming_tx, incoming_rx) = bounded::<Message>(0);
    let reader = thread::spawn(move || {
        let mut source = inp();
        let read_message = match kind {
            #[cfg(feature = "lsp")]
            crate::MessageKind::Lsp => Message::read_lsp::<I>,
            #[cfg(feature = "dap")]
            crate::MessageKind::Dap => Message::read_dap::<I>,
        };
        while let Some(msg) = read_message(&mut source)? {
            // Decide before sending (the message is moved into the channel).
            #[cfg(feature = "lsp")]
            let stop_after_send = {
                use crate::LspMessage;
                matches!(&msg, Message::Lsp(LspMessage::Notification(n)) if n.is_exit())
            };

            log::trace!("sending message {msg:#?}");
            incoming_tx
                .send(msg)
                .expect("receiver was dropped, failed to send a message");

            // An `exit` notification is the last message worth reading.
            #[cfg(feature = "lsp")]
            if stop_after_send {
                break;
            }
        }

        log::info!("reader thread finished");
        Ok(())
    });

    (outgoing_tx, incoming_rx, IoThreads { reader, writer })
}

/// A pair of threads for reading and writing LSP messages.
///
/// Both threads finish with an [`io::Result`]. Joining a thread that panicked
/// resumes that panic on the joining thread with its original payload.
pub struct IoThreads {
    reader: thread::JoinHandle<io::Result<()>>,
    writer: thread::JoinHandle<io::Result<()>>,
}

impl IoThreads {
    /// Waits for the reader thread, then the writer thread, to finish.
    ///
    /// # Errors
    ///
    /// Returns the reader's I/O error if it failed; in that case the writer is
    /// not joined. Otherwise returns the writer's result.
    ///
    /// # Panics
    ///
    /// If a joined thread panicked, its panic is resumed on the calling thread.
    pub fn join(self) -> io::Result<()> {
        Self::join_thread(self.reader, "reader")?;
        Self::join_thread(self.writer, "writer")
    }

    /// Waits for the writer thread to finish; the reader thread is not joined.
    ///
    /// # Errors
    ///
    /// Returns the writer thread's I/O error, if any.
    ///
    /// # Panics
    ///
    /// If the writer thread panicked, its panic is resumed on the calling thread.
    pub fn join_write(self) -> io::Result<()> {
        Self::join_thread(self.writer, "writer")
    }

    /// Joins one I/O thread and re-raises its panic with the original payload.
    fn join_thread(handle: thread::JoinHandle<io::Result<()>>, name: &str) -> io::Result<()> {
        handle.join().unwrap_or_else(|payload| {
            eprintln!("{name} panicked!");
            // `resume_unwind` keeps the original payload (e.g. the panic
            // message); `panic_any(payload)` would wrap it in another box.
            std::panic::resume_unwind(payload)
        })
    }
}

struct MirrorWriter<R: Read, W: Write>(R, W, std::sync::Once);

impl<R: Read, W: Write> Read for MirrorWriter<R, W> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let res = self.0.read(buf)?;

        if let Err(err) = self.1.write_all(&buf[..res]) {
            self.2.call_once(|| {
                log::warn!("failed to write to mirror: {err}");
            });
        }

        Ok(res)
    }
}

impl<R: Read + BufRead, W: Write> BufRead for MirrorWriter<R, W> {
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        self.0.fill_buf()
    }

    fn consume(&mut self, amt: usize) {
        let buf = self.0.fill_buf().unwrap();

        if let Err(err) = self.1.write_all(&buf[..amt]) {
            self.2.call_once(|| {
                log::warn!("failed to write to mirror: {err}");
            });
        }

        self.0.consume(amt);
    }
}