1use std::ops::Deref;
2
3use comemo::Track;
4use serde::{Deserialize, Serialize};
5use tinymist_analysis::analyze_expr;
6use tinymist_project::{DiagnosticFormat, PathPattern};
7use tinymist_std::error::prelude::*;
8use tinymist_world::vfs::WorkspaceResolver;
9use tinymist_world::{EntryReader, EntryState, ShadowApi, diag::print_diagnostics_to_string};
10use typst::diag::{At, SourceResult};
11use typst::foundations::{Args, Dict, NativeFunc, eco_format};
12use typst::syntax::Span;
13use typst::utils::LazyHash;
14use typst::{
15 foundations::{Bytes, IntoValue, StyleChain},
16 text::TextElem,
17};
18use typst_shim::eval::{Eval, Vm};
19use typst_shim::syntax::LinkedNodeExt;
20
21use crate::{
22 prelude::*,
23 syntax::{InterpretMode, interpret_mode_at},
24};
25
/// A single query about the code context of a source file.
///
/// Deserialized as an internally tagged object whose `kind` field selects the
/// variant (`"pathAt"`, `"modeAt"` or `"styleAt"`).
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "kind", rename_all = "camelCase")]
pub enum InteractCodeContextQuery {
    /// Evaluates a path expression against the current entry.
    PathAt {
        /// Either a `{ ... }` code expression or a path pattern such as
        /// `$root/$dir/$name`.
        code: String,
        /// Extra entries merged into `sys.inputs` when `code` is a code
        /// expression.
        inputs: Dict,
    },
    /// Gets the interpretation mode at a position.
    ModeAt {
        /// The position in the source file named by the request's `path`.
        position: LspPosition,
    },
    /// Gets the values of named styles at a position.
    StyleAt {
        /// The position in the source file named by the request's `path`.
        position: LspPosition,
        /// The style names to query; currently only `"text.font"` yields a
        /// value.
        style: Vec<String>,
    },
}
75
/// The answer to one [`InteractCodeContextQuery`].
///
/// Serialized as an internally tagged object whose `kind` field mirrors the
/// kind of the query it answers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "camelCase")]
pub enum InteractCodeContextResponse {
    /// The evaluated path, or the evaluation error, for a `PathAt` query.
    PathAt(QueryResult<serde_json::Value>),
    /// The result of a `ModeAt` query.
    ModeAt {
        /// The interpretation mode at the queried position.
        mode: InterpretMode,
    },
    /// The result of a `StyleAt` query.
    StyleAt {
        /// One entry per requested style name, in request order; `None` for
        /// unsupported names or values that could not be serialized.
        style: Vec<Option<JsonValue>>,
    },
}
93
/// A batch of code-context queries against one source file.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "kind")]
pub struct InteractCodeContextRequest {
    /// Path of the source file the queries refer to.
    pub path: PathBuf,
    /// The queries to answer; a `None` entry yields `None` at the same index
    /// of the response.
    pub query: Vec<Option<InteractCodeContextQuery>>,
}
103
104impl SemanticRequest for InteractCodeContextRequest {
105 type Response = Vec<Option<InteractCodeContextResponse>>;
106
107 fn request(self, ctx: &mut LocalContext) -> Option<Self::Response> {
108 let mut responses = Vec::new();
109
110 let source = ctx.source_by_path(&self.path).ok()?;
111
112 for query in self.query {
113 responses.push(query.and_then(|query| match query {
114 InteractCodeContextQuery::PathAt { code, inputs: base } => {
115 let res = eval_path_expr(ctx, &code, base)?;
116 Some(InteractCodeContextResponse::PathAt(res))
117 }
118 InteractCodeContextQuery::ModeAt { position } => {
119 let cursor = ctx.to_typst_pos(position, &source)?;
120 let mode = Self::mode_at(&source, cursor)?;
121 Some(InteractCodeContextResponse::ModeAt { mode })
122 }
123 InteractCodeContextQuery::StyleAt { position, style } => {
124 let mut world = ctx.world().clone();
125 log::info!(
126 "style at position {position:?} . {style:?} when main is {:?}",
127 world.main()
128 );
129 let cursor = ctx.to_typst_pos(position, &source)?;
130 let root = LinkedNode::new(source.root());
131 let mut leaf = root.leaf_at_compat(cursor)?;
132 log::info!("style at leaf {leaf:?} . {style:?}");
133
134 if !matches!(leaf.kind(), SyntaxKind::Text | SyntaxKind::MathText) {
135 return None;
136 }
137
138 if matches!(leaf.parent_kind(), Some(SyntaxKind::Raw)) {
139 leaf = leaf.parent()?.clone();
140 }
141
142 let mode = Self::mode_at(&source, cursor);
143 if !matches!(
144 mode,
145 Some(InterpretMode::Code | InterpretMode::Markup | InterpretMode::Math)
146 ) {
147 leaf = leaf.parent()?.clone();
148 }
149 let mut mapped_source = source.clone();
150 let (with, offset) = match mode {
151 Some(InterpretMode::Code) => ("context text.font", 8),
152 _ => ("#context text.font", 10),
153 };
154 let start = leaf.range().start;
155 mapped_source.edit(leaf.range(), with);
156
157 let _ = world.map_shadow_by_id(
158 mapped_source.id(),
159 Bytes::new(mapped_source.text().as_bytes().to_vec()),
160 );
161 world.take_db();
162
163 let root = LinkedNode::new(mapped_source.root());
164 let leaf = root.leaf_at_compat(start + offset)?;
165
166 log::info!("style at new_leaf {leaf:?} . {style:?}");
167
168 let mut cursor_styles = analyze_expr(&world, &leaf)
169 .iter()
170 .filter_map(|s| s.1.clone())
171 .collect::<Vec<_>>();
172 cursor_styles.sort_by_key(|x| x.as_slice().len());
173 log::info!("style at styles {cursor_styles:?} . {style:?}");
174 let cursor_style = cursor_styles.into_iter().next_back().unwrap_or_default();
175 let cursor_style = StyleChain::new(&cursor_style);
176
177 log::info!("style at style {cursor_style:?} . {style:?}");
178
179 let style = style
180 .iter()
181 .map(|style| Self::style_at(cursor_style, style))
182 .collect();
183 let _ = world.map_shadow_by_id(
184 mapped_source.id(),
185 Bytes::new(source.text().as_bytes().to_vec()),
186 );
187
188 Some(InteractCodeContextResponse::StyleAt { style })
189 }
190 }));
191 }
192
193 Some(responses)
194 }
195}
196
197impl InteractCodeContextRequest {
198 fn mode_at(source: &Source, pos: usize) -> Option<InterpretMode> {
199 if pos == 0 || pos >= source.text().len() {
201 return Some(InterpretMode::Markup);
202 }
203
204 let root = LinkedNode::new(source.root());
206 Some(interpret_mode_at(root.leaf_at_compat(pos).as_ref()))
207 }
208
209 fn style_at(cursor_style: StyleChain, style: &str) -> Option<JsonValue> {
210 match style {
211 "text.font" => {
212 let font = cursor_style.get_cloned(TextElem::font).into_value();
213 serde_json::to_value(font).ok()
214 }
215 _ => None,
216 }
217 }
218}
219
220fn eval_path_expr(
221 ctx: &mut LocalContext,
222 code: &str,
223 inputs: Dict,
224) -> Option<QueryResult<serde_json::Value>> {
225 let entry = ctx.world().entry_state();
226 let path = if code.starts_with("{") && code.ends_with("}") {
227 let id = entry
228 .select_in_workspace(Path::new("/__path__.typ"))
229 .main()?;
230
231 let inputs = make_sys(&entry, ctx.world().inputs(), inputs);
232 let (inputs, root, dir, name) = match inputs {
233 Some(EvalSysCtx {
234 inputs,
235 root,
236 dir,
237 name,
238 }) => (Some(inputs), Some(root), dir, Some(name)),
239 None => (None, None, None, None),
240 };
241
242 let mut world = ctx.world().task(tinymist_world::TaskInputs {
243 entry: None,
244 inputs,
245 });
246 world.take_db();
248 let _ = world.map_shadow_by_id(id, Bytes::from_string(code.to_owned()));
249
250 tinymist_analysis::upstream::with_vm((&world as &dyn World).track(), |vm| {
251 define_val(vm, "join", Value::Func(join::data().into()));
252 for (key, value) in [("root", root), ("dir", dir), ("name", name)] {
253 if let Some(value) = value {
254 define_val(vm, key, value);
255 }
256 }
257
258 let mut expr = typst::syntax::parse_code(code);
259 let span = Span::from_range(id, 0..code.len());
260 expr.synthesize(span);
261
262 let expr = match expr.cast::<ast::Code>() {
263 Some(v) => v,
264 None => bail!(
265 "code is not a valid code expression: kind={:?}",
266 expr.kind()
267 ),
268 };
269 match expr.eval(vm) {
270 Ok(value) => serde_json::to_value(value).context_ut("failed to serialize path"),
271 Err(e) => {
272 let res =
273 print_diagnostics_to_string(&world, e.iter(), DiagnosticFormat::Human);
274 let err = res.unwrap_or_else(|e| e);
275 bail!("failed to evaluate path expression: {err}")
276 }
277 }
278 })
279 } else {
280 PathPattern::new(code)
281 .substitute(&entry)
282 .context_ut("failed to substitute path pattern")
283 .and_then(|path| {
284 serde_json::to_value(path.deref()).context_ut("failed to serialize path")
285 })
286 };
287 Some(path.into())
288}
289
/// Evaluation context built by `make_sys` for path code expressions.
#[derive(Debug, Clone, Hash)]
struct EvalSysCtx {
    /// Base `sys.inputs` with the query inputs merged in, plus the `root`,
    /// `dir` (if any) and `name` entries below.
    inputs: Arc<LazyHash<Dict>>,
    /// The workspace root path, as a string value.
    root: Value,
    /// The main file's parent directory relative to the root, if it has one.
    dir: Option<Value>,
    /// The main file's name with any `.typ` suffix stripped.
    name: Value,
}
297
298#[comemo::memoize]
299fn make_sys(entry: &EntryState, base: Arc<LazyHash<Dict>>, inputs: Dict) -> Option<EvalSysCtx> {
300 let root = entry.root();
301 let main = entry.main();
302
303 log::debug!("Check path {main:?} and root {root:?}");
304
305 let (root, main) = root.zip(main)?;
306
307 if WorkspaceResolver::is_package_file(main) {
309 return None;
310 }
311 let path = main.vpath().resolve(&root)?;
313
314 if path.strip_prefix("/untitled").is_ok() {
316 return None;
317 }
318
319 let path = path.strip_prefix(&root).ok()?;
320 let dir = path.parent();
321 let file_name = path.file_name().unwrap_or_default();
322
323 let root = Value::Str(root.to_string_lossy().into());
324
325 let dir = dir.map(|d| Value::Str(d.to_string_lossy().into()));
326
327 let name = file_name.to_string_lossy();
328 let name = name.as_ref().strip_suffix(".typ").unwrap_or(name.as_ref());
329 let name = Value::Str(name.into());
330
331 let mut dict = base.as_ref().deref().clone();
332 for (key, value) in inputs {
333 dict.insert(key, value);
334 }
335 dict.insert("root".into(), root.clone());
336 if let Some(dir) = &dir {
337 dict.insert("dir".into(), dir.clone());
338 }
339 dict.insert("name".into(), name.clone());
340
341 Some(EvalSysCtx {
342 inputs: Arc::new(LazyHash::new(dict)),
343 root,
344 dir,
345 name,
346 })
347}
348
349fn define_val(vm: &mut Vm, name: &str, value: Value) {
350 let ident = SyntaxNode::leaf(SyntaxKind::Ident, name);
351 vm.define(ident.cast::<ast::Ident>().unwrap(), value);
352}
353
354#[typst_macros::func(title = "Join function")]
355fn join(args: &mut Args) -> SourceResult<Value> {
356 let pos = args.take().to_pos();
357 let mut res = PathBuf::new();
358 for arg in pos {
359 match arg {
360 Value::Str(s) => res.push(s.as_str()),
361 _ => {
362 return Err(eco_format!("join argument is not a string: {arg:?}")).at(args.span);
363 }
364 };
365 }
366 Ok(Value::Str(res.to_string_lossy().into()))
367}
368
/// Outcome of a query.
///
/// Serialized untagged, i.e. as either `{ "value": ... }` or
/// `{ "error": ... }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum QueryResult<T> {
    /// The query succeeded.
    Success {
        /// The computed value.
        value: T,
    },
    /// The query failed.
    Error {
        /// A human-readable description of the failure.
        error: EcoString,
    },
}
384
385impl<T> QueryResult<T> {
386 pub fn success(value: T) -> Self {
388 Self::Success { value }
389 }
390
391 pub fn error(error: EcoString) -> Self {
393 Self::Error { error }
394 }
395}
396
397impl<T, E: std::error::Error> From<Result<T, E>> for QueryResult<T> {
398 fn from(value: Result<T, E>) -> Self {
399 match value {
400 Ok(value) => QueryResult::success(value),
401 Err(error) => QueryResult::error(eco_format!("{error}")),
402 }
403 }
404}
405
#[cfg(test)]
mod tests {
    use typst::foundations::dict;

    use super::*;
    use crate::tests::*;

    /// Snapshot test for `PathAt` queries over path patterns and `{ ... }`
    /// code expressions.
    #[test]
    fn test() {
        snapshot_testing("code_context_path_at", &|ctx, path| {
            // Plain path patterns first, then code expressions.
            let patterns = [
                "$root/$dir/$name",
                "$root/$name",
                "$root/assets",
                "$root/assets/$name",
                r#"{ join(root, "x", dir, "y", name) }"#,
                // Non-string argument: exercises the error path of `join`.
                r#"{ join(root, 1) }"#,
                // `roo` is not among the names bound by `eval_path_expr`, so
                // evaluation presumably fails with a diagnostic.
                r#"{ join(roo, 1) }"#,
            ];
            // Extra `sys.inputs` entries supplied with the queries.
            let inp = [
                dict! {
                    "x-path-context" => "vscode-paste",
                    "x-path-input-uri" => "https://huh.io/img.png",
                    "x-path-input-name" => "img.png",
                },
                dict! {
                    "x-path-context" => "vscode-paste",
                    "x-path-input-uri" => "https://huh.io/text.md",
                    "x-path-input-name" => "text.md",
                },
            ];

            // Every pattern runs with `inp[0]`. Then a `/resolve.typ` import
            // (presumably a fixture in the snapshot workspace) runs once per
            // input dict.
            let cases = patterns
                .iter()
                .map(|pat| (*pat, inp[0].clone()))
                .chain(inp.iter().map(|inp| {
                    (
                        r#"{ import "/resolve.typ": resolve; resolve(join, root, dir, name) }"#,
                        inp.clone(),
                    )
                }));

            let result = cases
                .map(|(code, inputs)| {
                    let request = InteractCodeContextRequest {
                        path: path.clone(),
                        query: vec![Some(InteractCodeContextQuery::PathAt {
                            code: code.to_string(),
                            inputs: inputs.clone(),
                        })],
                    };
                    json!({ "code": code, "inputs": inputs, "response": request.request(ctx) })
                })
                .collect::<Vec<_>>();
            assert_snapshot!(JsonRepr::new_redacted(result, &REDACT_LOC));
        });
    }
}