tinymist_query/
code_context.rs

1use std::ops::Deref;
2
3use comemo::Track;
4use serde::{Deserialize, Serialize};
5use tinymist_analysis::analyze_expr;
6use tinymist_project::{DiagnosticFormat, PathPattern};
7use tinymist_std::error::prelude::*;
8use tinymist_world::vfs::WorkspaceResolver;
9use tinymist_world::{EntryReader, EntryState, ShadowApi, diag::print_diagnostics_to_string};
10use typst::diag::{At, SourceResult};
11use typst::foundations::{Args, Dict, NativeFunc, eco_format};
12use typst::syntax::Span;
13use typst::utils::LazyHash;
14use typst::{
15    foundations::{Bytes, IntoValue, StyleChain},
16    text::TextElem,
17};
18use typst_shim::eval::{Eval, Vm};
19use typst_shim::syntax::LinkedNodeExt;
20
21use crate::{
22    prelude::*,
23    syntax::{InterpretMode, interpret_mode_at},
24};
25
26/// A query to get the mode at a specific position in a text document.
27#[derive(Debug, Clone, Deserialize)]
28#[serde(tag = "kind", rename_all = "camelCase")]
29pub enum InteractCodeContextQuery {
30    /// (Experimental) Evaluate a path expression at a specific position in a
31    /// text document.
32    PathAt {
33        /// Code to evaluate. If the code starts with `{` and ends with `}`, it
34        /// will be evaluated as a code expression, otherwise it will be
35        /// evaluated as a path pattern.
36        ///
37        /// ## Example
38        ///
39        /// evaluate a path pattern, which could use following definitions:
40        ///
41        /// ```plain
42        /// $root/x/$dir/../$name // is evaluated as
43        /// /path/to/root/x/dir/../main
44        /// ```
45        ///
46        /// ## Example
47        ///
48        /// evaluate a code expression, which could use following definitions:
49        /// - `root`: the root of the workspace
50        /// - `dir`: the directory of the current file
51        /// - `name`: the name of the current file
52        /// - `join(a, b, ...)`: join the arguments with the path separator
53        ///
54        /// ```plain
55        /// { join(root, "x", dir, "y", name) } // is evaluated as
56        /// /path/to/root/x/dir/y/main
57        /// ```
58        code: String,
59        /// The extra `sys.inputs` for the code expression.
60        inputs: Dict,
61    },
62    /// Get the mode at a specific position in a text document.
63    ModeAt {
64        /// The position inside the text document.
65        position: LspPosition,
66    },
67    /// Get the style at a specific position in a text document.
68    StyleAt {
69        /// The position inside the text document.
70        position: LspPosition,
71        /// Style to query
72        style: Vec<String>,
73    },
74}
75
76/// A response to a `InteractCodeContextQuery`.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78#[serde(tag = "kind", rename_all = "camelCase")]
79pub enum InteractCodeContextResponse {
80    /// Evaluate a path expression at a specific position in a text document.
81    PathAt(QueryResult<serde_json::Value>),
82    /// Get the mode at a specific position in a text document.
83    ModeAt {
84        /// The mode at the requested position.
85        mode: InterpretMode,
86    },
87    /// Get the style at a specific position in a text document.
88    StyleAt {
89        /// The style at the requested position.
90        style: Vec<Option<JsonValue>>,
91    },
92}
93
/// A request to get the code context of a text document.
///
/// Each query is answered independently against the document at `path`.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "kind")]
pub struct InteractCodeContextRequest {
    /// The path to the text document.
    pub path: PathBuf,
    /// The queries to execute. The response contains one entry per query at
    /// the same index; a `None` query yields a `None` response.
    pub query: Vec<Option<InteractCodeContextQuery>>,
}
103
104impl SemanticRequest for InteractCodeContextRequest {
105    type Response = Vec<Option<InteractCodeContextResponse>>;
106
107    fn request(self, ctx: &mut LocalContext) -> Option<Self::Response> {
108        let mut responses = Vec::new();
109
110        let source = ctx.source_by_path(&self.path).ok()?;
111
112        for query in self.query {
113            responses.push(query.and_then(|query| match query {
114                InteractCodeContextQuery::PathAt { code, inputs: base } => {
115                    let res = eval_path_expr(ctx, &code, base)?;
116                    Some(InteractCodeContextResponse::PathAt(res))
117                }
118                InteractCodeContextQuery::ModeAt { position } => {
119                    let cursor = ctx.to_typst_pos(position, &source)?;
120                    let mode = Self::mode_at(&source, cursor)?;
121                    Some(InteractCodeContextResponse::ModeAt { mode })
122                }
123                InteractCodeContextQuery::StyleAt { position, style } => {
124                    let mut world = ctx.world().clone();
125                    log::info!(
126                        "style at position {position:?} . {style:?} when main is {:?}",
127                        world.main()
128                    );
129                    let cursor = ctx.to_typst_pos(position, &source)?;
130                    let root = LinkedNode::new(source.root());
131                    let mut leaf = root.leaf_at_compat(cursor)?;
132                    log::info!("style at leaf {leaf:?} . {style:?}");
133
134                    if !matches!(leaf.kind(), SyntaxKind::Text | SyntaxKind::MathText) {
135                        return None;
136                    }
137
138                    if matches!(leaf.parent_kind(), Some(SyntaxKind::Raw)) {
139                        leaf = leaf.parent()?.clone();
140                    }
141
142                    let mode = Self::mode_at(&source, cursor);
143                    if !matches!(
144                        mode,
145                        Some(InterpretMode::Code | InterpretMode::Markup | InterpretMode::Math)
146                    ) {
147                        leaf = leaf.parent()?.clone();
148                    }
149                    let mut mapped_source = source.clone();
150                    let (with, offset) = match mode {
151                        Some(InterpretMode::Code) => ("context text.font", 8),
152                        _ => ("#context text.font", 10),
153                    };
154                    let start = leaf.range().start;
155                    mapped_source.edit(leaf.range(), with);
156
157                    let _ = world.map_shadow_by_id(
158                        mapped_source.id(),
159                        Bytes::new(mapped_source.text().as_bytes().to_vec()),
160                    );
161                    world.take_db();
162
163                    let root = LinkedNode::new(mapped_source.root());
164                    let leaf = root.leaf_at_compat(start + offset)?;
165
166                    log::info!("style at new_leaf {leaf:?} . {style:?}");
167
168                    let mut cursor_styles = analyze_expr(&world, &leaf)
169                        .iter()
170                        .filter_map(|s| s.1.clone())
171                        .collect::<Vec<_>>();
172                    cursor_styles.sort_by_key(|x| x.as_slice().len());
173                    log::info!("style at styles {cursor_styles:?} . {style:?}");
174                    let cursor_style = cursor_styles.into_iter().next_back().unwrap_or_default();
175                    let cursor_style = StyleChain::new(&cursor_style);
176
177                    log::info!("style at style {cursor_style:?} . {style:?}");
178
179                    let style = style
180                        .iter()
181                        .map(|style| Self::style_at(cursor_style, style))
182                        .collect();
183                    let _ = world.map_shadow_by_id(
184                        mapped_source.id(),
185                        Bytes::new(source.text().as_bytes().to_vec()),
186                    );
187
188                    Some(InteractCodeContextResponse::StyleAt { style })
189                }
190            }));
191        }
192
193        Some(responses)
194    }
195}
196
197impl InteractCodeContextRequest {
198    fn mode_at(source: &Source, pos: usize) -> Option<InterpretMode> {
199        // Smart special cases that is definitely at markup
200        if pos == 0 || pos >= source.text().len() {
201            return Some(InterpretMode::Markup);
202        }
203
204        // Get mode
205        let root = LinkedNode::new(source.root());
206        Some(interpret_mode_at(root.leaf_at_compat(pos).as_ref()))
207    }
208
209    fn style_at(cursor_style: StyleChain, style: &str) -> Option<JsonValue> {
210        match style {
211            "text.font" => {
212                let font = cursor_style.get_cloned(TextElem::font).into_value();
213                serde_json::to_value(font).ok()
214            }
215            _ => None,
216        }
217    }
218}
219
220fn eval_path_expr(
221    ctx: &mut LocalContext,
222    code: &str,
223    inputs: Dict,
224) -> Option<QueryResult<serde_json::Value>> {
225    let entry = ctx.world().entry_state();
226    let path = if code.starts_with("{") && code.ends_with("}") {
227        let id = entry
228            .select_in_workspace(Path::new("/__path__.typ"))
229            .main()?;
230
231        let inputs = make_sys(&entry, ctx.world().inputs(), inputs);
232        let (inputs, root, dir, name) = match inputs {
233            Some(EvalSysCtx {
234                inputs,
235                root,
236                dir,
237                name,
238            }) => (Some(inputs), Some(root), dir, Some(name)),
239            None => (None, None, None, None),
240        };
241
242        let mut world = ctx.world().task(tinymist_world::TaskInputs {
243            entry: None,
244            inputs,
245        });
246        // todo: bad performance
247        world.take_db();
248        let _ = world.map_shadow_by_id(id, Bytes::from_string(code.to_owned()));
249
250        tinymist_analysis::upstream::with_vm((&world as &dyn World).track(), |vm| {
251            define_val(vm, "join", Value::Func(join::data().into()));
252            for (key, value) in [("root", root), ("dir", dir), ("name", name)] {
253                if let Some(value) = value {
254                    define_val(vm, key, value);
255                }
256            }
257
258            let mut expr = typst::syntax::parse_code(code);
259            let span = Span::from_range(id, 0..code.len());
260            expr.synthesize(span);
261
262            let expr = match expr.cast::<ast::Code>() {
263                Some(v) => v,
264                None => bail!(
265                    "code is not a valid code expression: kind={:?}",
266                    expr.kind()
267                ),
268            };
269            match expr.eval(vm) {
270                Ok(value) => serde_json::to_value(value).context_ut("failed to serialize path"),
271                Err(e) => {
272                    let res =
273                        print_diagnostics_to_string(&world, e.iter(), DiagnosticFormat::Human);
274                    let err = res.unwrap_or_else(|e| e);
275                    bail!("failed to evaluate path expression: {err}")
276                }
277            }
278        })
279    } else {
280        PathPattern::new(code)
281            .substitute(&entry)
282            .context_ut("failed to substitute path pattern")
283            .and_then(|path| {
284                serde_json::to_value(path.deref()).context_ut("failed to serialize path")
285            })
286    };
287    Some(path.into())
288}
289
/// Values made available to a path code expression, as built by `make_sys`.
#[derive(Debug, Clone, Hash)]
struct EvalSysCtx {
    /// The world's `sys.inputs`, overlaid with the request inputs and with the
    /// `root`, `dir` (when present), and `name` entries below.
    inputs: Arc<LazyHash<Dict>>,
    /// The workspace root path, as a string value.
    root: Value,
    /// The parent directory of the main file relative to `root`, if any.
    dir: Option<Value>,
    /// The main file's name with a trailing `.typ` removed.
    name: Value,
}
297
298#[comemo::memoize]
299fn make_sys(entry: &EntryState, base: Arc<LazyHash<Dict>>, inputs: Dict) -> Option<EvalSysCtx> {
300    let root = entry.root();
301    let main = entry.main();
302
303    log::debug!("Check path {main:?} and root {root:?}");
304
305    let (root, main) = root.zip(main)?;
306
307    // Files in packages are not exported
308    if WorkspaceResolver::is_package_file(main) {
309        return None;
310    }
311    // Files without a path are not exported
312    let path = main.vpath().resolve(&root)?;
313
314    // todo: handle untitled path
315    if path.strip_prefix("/untitled").is_ok() {
316        return None;
317    }
318
319    let path = path.strip_prefix(&root).ok()?;
320    let dir = path.parent();
321    let file_name = path.file_name().unwrap_or_default();
322
323    let root = Value::Str(root.to_string_lossy().into());
324
325    let dir = dir.map(|d| Value::Str(d.to_string_lossy().into()));
326
327    let name = file_name.to_string_lossy();
328    let name = name.as_ref().strip_suffix(".typ").unwrap_or(name.as_ref());
329    let name = Value::Str(name.into());
330
331    let mut dict = base.as_ref().deref().clone();
332    for (key, value) in inputs {
333        dict.insert(key, value);
334    }
335    dict.insert("root".into(), root.clone());
336    if let Some(dir) = &dir {
337        dict.insert("dir".into(), dir.clone());
338    }
339    dict.insert("name".into(), name.clone());
340
341    Some(EvalSysCtx {
342        inputs: Arc::new(LazyHash::new(dict)),
343        root,
344        dir,
345        name,
346    })
347}
348
349fn define_val(vm: &mut Vm, name: &str, value: Value) {
350    let ident = SyntaxNode::leaf(SyntaxKind::Ident, name);
351    vm.define(ident.cast::<ast::Ident>().unwrap(), value);
352}
353
354#[typst_macros::func(title = "Join function")]
355fn join(args: &mut Args) -> SourceResult<Value> {
356    let pos = args.take().to_pos();
357    let mut res = PathBuf::new();
358    for arg in pos {
359        match arg {
360            Value::Str(s) => res.push(s.as_str()),
361            _ => {
362                return Err(eco_format!("join argument is not a string: {arg:?}")).at(args.span);
363            }
364        };
365    }
366    Ok(Value::Str(res.to_string_lossy().into()))
367}
368
/// A result of a query.
///
/// Serialized without a tag (`#[serde(untagged)]`). A success is written as
/// `{ "value": ... }` and an error as `{ "error": "..." }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum QueryResult<T> {
    /// A successful result.
    Success {
        /// The value of the result.
        value: T,
    },
    /// An error result.
    Error {
        /// The error message.
        error: EcoString,
    },
}
384
385impl<T> QueryResult<T> {
386    /// Creates a successful result.
387    pub fn success(value: T) -> Self {
388        Self::Success { value }
389    }
390
391    /// Creates an error result.
392    pub fn error(error: EcoString) -> Self {
393        Self::Error { error }
394    }
395}
396
397impl<T, E: std::error::Error> From<Result<T, E>> for QueryResult<T> {
398    fn from(value: Result<T, E>) -> Self {
399        match value {
400            Ok(value) => QueryResult::success(value),
401            Err(error) => QueryResult::error(eco_format!("{error}")),
402        }
403    }
404}
405
#[cfg(test)]
mod tests {
    use typst::foundations::dict;

    use super::*;
    use crate::tests::*;

    /// Snapshot test for `PathAt` queries. Each case is evaluated for the
    /// document at `path`, and the serialized requests and responses are
    /// compared with the stored `code_context_path_at` snapshot.
    #[test]
    fn test() {
        snapshot_testing("code_context_path_at", &|ctx, path| {
            // Path patterns first, then code expressions. The last two are
            // expected to fail: `join` rejects the non-string `1`, and `roo`
            // is not a defined name.
            let patterns = [
                "$root/$dir/$name",
                "$root/$name",
                "$root/assets",
                "$root/assets/$name",
                r#"{ join(root, "x", dir, "y", name) }"#,
                r#"{ join(root, 1) }"#,
                r#"{ join(roo, 1) }"#,
            ];
            // Two sets of extra `sys.inputs` passed along with the code.
            let inp = [
                dict! {
                    "x-path-context" => "vscode-paste",
                    "x-path-input-uri" => "https://huh.io/img.png",
                    "x-path-input-name" => "img.png",
                },
                dict! {
                    "x-path-context" => "vscode-paste",
                    "x-path-input-uri" => "https://huh.io/text.md",
                    "x-path-input-name" => "text.md",
                },
            ];

            // Every pattern runs with the first input set. In addition, a
            // helper imported from `/resolve.typ` runs once per input set.
            // The order of cases is part of the snapshot.
            let cases = patterns
                .iter()
                .map(|pat| (*pat, inp[0].clone()))
                .chain(inp.iter().map(|inp| {
                    (
                        r#"{ import "/resolve.typ": resolve; resolve(join, root, dir, name) }"#,
                        inp.clone(),
                    )
                }));

            let result = cases
                .map(|(code, inputs)| {
                    let request = InteractCodeContextRequest {
                        path: path.clone(),
                        query: vec![Some(InteractCodeContextQuery::PathAt {
                            code: code.to_string(),
                            inputs: inputs.clone(),
                        })],
                    };
                    json!({ "code": code, "inputs": inputs, "response": request.request(ctx) })
                })
                .collect::<Vec<_>>();
            assert_snapshot!(JsonRepr::new_redacted(result, &REDACT_LOC));
        });
    }
}