tinymist_query/analysis/
post_tyck.rs

1//! Infer more than the principal type of some expression.
2
3use std::collections::HashSet;
4use tinymist_derive::BindTyCtx;
5
6use super::{
7    ArgsTy, Sig, SigChecker, SigShape, SigSurfaceKind, SigTy, Ty, TyCtx, TyCtxMut, TypeBounds,
8    TypeInfo, TypeVar,
9};
10use super::{DynTypeBounds, ParamAttrs, ParamTy, SharedContext, prelude::*};
11use crate::syntax::{ArgClass, SyntaxContext, VarClass, classify_context, classify_context_outer};
12use crate::ty::{BuiltinTy, RecordTy};
13
14/// With given type information, check the type of a literal expression again by
15/// touching the possible related nodes.
16#[typst_macros::time(span = node.span())]
17pub(crate) fn post_type_check(
18    ctx: Arc<SharedContext>,
19    ti: &TypeInfo,
20    node: LinkedNode,
21) -> Option<Ty> {
22    let mut checker = PostTypeChecker::new(ctx, ti);
23    let res = checker.check(&node);
24    checker.simplify(&res?)
25}
26
27#[derive(Default)]
28struct SignatureReceiver {
29    lbs_dedup: HashSet<Ty>,
30    ubs_dedup: HashSet<Ty>,
31    bounds: TypeBounds,
32}
33
34impl SignatureReceiver {
35    fn insert(&mut self, ty: Ty, pol: bool) {
36        crate::log_debug_ct!("post check receive: {ty:?}");
37        if !pol {
38            if self.lbs_dedup.insert(ty.clone()) {
39                self.bounds.lbs.push(ty);
40            }
41        } else if self.ubs_dedup.insert(ty.clone()) {
42            self.bounds.ubs.push(ty);
43        }
44    }
45
46    fn finalize(self) -> Ty {
47        Ty::Let(self.bounds.into())
48    }
49}
50
51fn check_signature<'a>(
52    receiver: &'a mut SignatureReceiver,
53    arg: &'a ArgClass,
54) -> impl FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()> + 'a {
55    move |worker, sig, args, pol| {
56        let (sig, _is_partialize) = match sig {
57            Sig::Partialize(sig) => (*sig, true),
58            sig => (sig, false),
59        };
60
61        let SigShape { sig: sig_ins, .. } = sig.shape(worker)?;
62
63        match &arg {
64            ArgClass::Named(n) => {
65                let ident = n.cast::<ast::Ident>()?;
66                let ty = sig_ins.named(&ident.into())?;
67                receiver.insert(ty.clone(), !pol);
68
69                Some(())
70            }
71            ArgClass::Positional {
72                // todo: spreads
73                spreads: _,
74                positional,
75                is_spread,
76            } => {
77                if *is_spread {
78                    return None;
79                }
80
81                // truncate args
82                let bound_pos = args
83                    .iter()
84                    .map(|args| args.positional_params().len())
85                    .sum::<usize>();
86                if let Some(nth) = sig_ins.pos_or_rest(bound_pos + positional) {
87                    receiver.insert(nth, !pol);
88                }
89
90                // names
91                for (name, ty) in sig_ins.named_params() {
92                    let field = ParamTy::new(ty.clone(), name.clone(), ParamAttrs::named());
93                    receiver.insert(Ty::Param(field), !pol);
94                }
95
96                Some(())
97            }
98        }
99    }
100}
101
102pub(crate) struct PostTypeChecker<'a> {
103    ctx: Arc<SharedContext>,
104    pub info: &'a TypeInfo,
105    checked: HashMap<Span, Option<Ty>>,
106    locals: TypeInfo,
107}
108
109impl TyCtx for PostTypeChecker<'_> {
110    fn global_bounds(&self, var: &Interned<TypeVar>, pol: bool) -> Option<DynTypeBounds> {
111        self.info.global_bounds(var, pol)
112    }
113
114    fn local_bind_of(&self, var: &Interned<TypeVar>) -> Option<Ty> {
115        self.locals.local_bind_of(var)
116    }
117}
118
119impl TyCtxMut for PostTypeChecker<'_> {
120    type Snap = <TypeInfo as TyCtxMut>::Snap;
121
122    fn start_scope(&mut self) -> Self::Snap {
123        self.locals.start_scope()
124    }
125
126    fn end_scope(&mut self, snap: Self::Snap) {
127        self.locals.end_scope(snap)
128    }
129
130    fn bind_local(&mut self, var: &Interned<TypeVar>, ty: Ty) {
131        self.locals.bind_local(var, ty);
132    }
133
134    fn type_of_func(&mut self, func: &Func) -> Option<Interned<SigTy>> {
135        Some(self.ctx.type_of_func(func.clone()).type_sig())
136    }
137
138    fn type_of_value(&mut self, val: &Value) -> Ty {
139        self.ctx.type_of_value(val)
140    }
141
142    fn check_module_item(&mut self, _module: TypstFileId, _key: &StrRef) -> Option<Ty> {
143        None
144    }
145}
146
147impl<'a> PostTypeChecker<'a> {
148    pub fn new(ctx: Arc<SharedContext>, info: &'a TypeInfo) -> Self {
149        Self {
150            ctx,
151            info,
152            checked: HashMap::new(),
153            locals: TypeInfo::default(),
154        }
155    }
156
157    fn check(&mut self, node: &LinkedNode) -> Option<Ty> {
158        let span = node.span();
159        if let Some(ty) = self.checked.get(&span) {
160            return ty.clone();
161        }
162        // loop detection
163        self.checked.insert(span, None);
164
165        let ty = self.check_(node);
166        self.checked.insert(span, ty.clone());
167        ty
168    }
169
170    fn simplify(&mut self, ty: &Ty) -> Option<Ty> {
171        Some(self.info.simplify(ty.clone(), false))
172    }
173
174    fn check_(&mut self, node: &LinkedNode) -> Option<Ty> {
175        let context = node.parent()?;
176        crate::log_debug_ct!("post check: {:?}::{:?}", context.kind(), node.kind());
177
178        let context_ty = self.check_context(context, node);
179        let self_ty = if !matches!(node.kind(), SyntaxKind::Label | SyntaxKind::Ref) {
180            self.info.type_of_span(node.span())
181        } else {
182            None
183        };
184
185        let can_penetrate_context = !(matches!(
186            node.kind(),
187            SyntaxKind::Hash | SyntaxKind::ContentBlock | SyntaxKind::CodeBlock
188        ) || matches!(context.kind(), SyntaxKind::FieldAccess) && {
189            let field_access = context.cast::<ast::FieldAccess>()?;
190            field_access.field().span() == node.span()
191        });
192
193        let contextual_self_ty = can_penetrate_context
194            .then(|| self.check_cursor(classify_context(node.clone(), None), context_ty));
195        crate::log_debug_ct!(
196            "post check(res): {:?}::{:?} -> {self_ty:?}, {contextual_self_ty:?}",
197            context.kind(),
198            node.kind(),
199        );
200
201        Ty::union(self_ty, contextual_self_ty.flatten())
202    }
203
204    fn check_or(&mut self, node: &LinkedNode, ty: Option<Ty>) -> Option<Ty> {
205        Ty::union(self.check(node), ty)
206    }
207
208    fn check_cursor(
209        &mut self,
210        cursor: Option<SyntaxContext>,
211        context_ty: Option<Ty>,
212    ) -> Option<Ty> {
213        let Some(cursor) = cursor else {
214            return context_ty;
215        };
216        crate::log_debug_ct!("post check target: {cursor:?}");
217
218        match &cursor {
219            SyntaxContext::Arg {
220                callee,
221                args: _,
222                target,
223                is_set,
224            } => {
225                let callee_ty = self.check_or(callee, context_ty)?;
226                crate::log_debug_ct!(
227                    "post check call target: ({callee_ty:?})::{target:?} is_set: {is_set}"
228                );
229
230                let sig = self.ctx.sig_of_type_or_dyn(self.info, callee_ty, callee)?;
231                crate::log_debug_ct!("post check call sig: {target:?} {sig:?}");
232                let mut resp = SignatureReceiver::default();
233
234                match target {
235                    ArgClass::Named(n) => {
236                        let ident = n.cast::<ast::Ident>()?.into();
237                        let ty = sig.primary().get_named(&ident)?;
238                        // todo: losing docs
239                        resp.insert(ty.ty.clone(), false);
240                    }
241                    ArgClass::Positional {
242                        // todo: spreads
243                        spreads: _,
244                        positional,
245                        is_spread,
246                    } => {
247                        if *is_spread {
248                            return None;
249                        }
250
251                        // truncate args
252                        let shift = sig.param_shift();
253                        let nth = sig
254                            .primary()
255                            .get_pos(shift + positional)
256                            .or_else(|| sig.primary().rest());
257                        if let Some(nth) = nth {
258                            resp.insert(Ty::Param(nth.clone()), false);
259                        }
260
261                        // names
262                        for field in sig.primary().named() {
263                            if *is_set && !field.attrs.settable {
264                                continue;
265                            }
266
267                            resp.insert(Ty::Param(field.clone()), false);
268                        }
269                    }
270                }
271
272                crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
273                Some(resp.finalize())
274            }
275            SyntaxContext::Element { container, target } => {
276                // The `Array` / `Dict` syntax node is often wrapped by a `Parenthesized`
277                // expression, which is where contextual typing (e.g. let-binding type) applies.
278                // Use the parenthesized container type when available so element types can be
279                // inferred from outer constraints.
280                // todo: however, user may continue edit a parenthesized expression to become
281                // array or dict in typst. We should ensure that we can identify such cases
282                // correctly.
283                let container_expr = match container.kind() {
284                    SyntaxKind::Array | SyntaxKind::Dict => container
285                        .parent()
286                        .cloned()
287                        .filter(|p| p.kind() == SyntaxKind::Parenthesized)
288                        .unwrap_or_else(|| container.clone()),
289                    _ => container.clone(),
290                };
291                let container_ty = self.check_or(&container_expr, context_ty)?;
292                crate::log_debug_ct!("post check element target: ({container_ty:?})::{target:?}");
293
294                let mut resp = SignatureReceiver::default();
295
296                self.check_element_of(
297                    &container_ty,
298                    false,
299                    container,
300                    &mut check_signature(&mut resp, target),
301                );
302
303                crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
304                Some(resp.finalize())
305            }
306            SyntaxContext::Paren {
307                container,
308                is_before,
309            } => {
310                let container_ty = self.check_or(container, context_ty)?;
311                crate::log_debug_ct!("post check paren target: {container_ty:?}::{is_before:?}");
312
313                let mut resp = SignatureReceiver::default();
314                // todo: this is legal, but it makes it sometimes complete itself.
315                // e.g. completing `""` on `let x = ("|")`
316                resp.bounds.lbs.push(container_ty.clone());
317
318                let target = ArgClass::first_positional();
319                self.check_element_of(
320                    &container_ty,
321                    false,
322                    container,
323                    &mut check_signature(&mut resp, &target),
324                );
325
326                crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
327                Some(resp.finalize())
328            }
329            SyntaxContext::ImportPath(..) | SyntaxContext::IncludePath(..) => {
330                Some(Ty::Builtin(BuiltinTy::Path(crate::ty::PathKind::Source {
331                    allow_package: true,
332                })))
333            }
334            SyntaxContext::VarAccess(VarClass::Ident(node))
335            | SyntaxContext::VarAccess(VarClass::FieldAccess(node))
336            | SyntaxContext::VarAccess(VarClass::DotAccess(node))
337            | SyntaxContext::Label { node, .. }
338            | SyntaxContext::Ref { node, .. }
339            | SyntaxContext::At { node, .. }
340            | SyntaxContext::Normal(node) => {
341                let label_or_ref_ty = match cursor {
342                    SyntaxContext::Label { is_error: true, .. } => {
343                        Some(Ty::Builtin(BuiltinTy::Label))
344                    }
345                    SyntaxContext::Ref {
346                        suffix_colon: true, ..
347                    } => Some(Ty::Builtin(BuiltinTy::RefLabel)),
348                    _ => None,
349                };
350                let ty = self.check_or(node, context_ty);
351                crate::log_debug_ct!("post check target normal: {ty:?} {label_or_ref_ty:?}");
352                ty.or(label_or_ref_ty)
353            }
354        }
355    }
356
357    /// Checks the context of a node and returns the type of the context.
358    fn check_context(&mut self, context: &LinkedNode, node: &LinkedNode) -> Option<Ty> {
359        match context.kind() {
360            SyntaxKind::LetBinding => {
361                let let_binding = context.cast::<ast::LetBinding>()?;
362                let let_init = let_binding.init()?;
363                let let_init_node = context.find(let_init.span())?;
364                let_init_node.find(node.span())?;
365
366                match let_binding.kind() {
367                    ast::LetBindingKind::Closure(_c) => None,
368                    ast::LetBindingKind::Normal(pattern) => self.check_let_pattern(pattern),
369                }
370            }
371            SyntaxKind::Args => self.check_cursor(
372                // todo: not well behaved
373                classify_context_outer(context.clone(), node.clone()),
374                None,
375            ),
376            // todo: constraint node
377            SyntaxKind::Named => self.check_cursor(classify_context(context.clone(), None), None),
378            _ => None,
379        }
380    }
381
382    /// Checks left-side of a let expression `let lhs = rhs` for the `rhs`.
383    fn check_let_pattern(&mut self, pattern: ast::Pattern) -> Option<Ty> {
384        match pattern {
385            ast::Pattern::Placeholder(_) => None,
386            ast::Pattern::Normal(ast::Expr::Ident(ident)) => self.info.type_of_span(ident.span()),
387            ast::Pattern::Normal(..) => None,
388            ast::Pattern::Parenthesized(expr) => {
389                self.check_let_pattern(expr.expr().to_untyped().cast()?)
390            }
391            ast::Pattern::Destructuring(expr) => self.check_let_destruct(expr),
392        }
393    }
394
395    /// Checks left-side of a let expression `let lhs = rhs` for the `rhs`.
396    fn check_let_destruct(&mut self, destructuring: ast::Destructuring) -> Option<Ty> {
397        let mut pos = vec![];
398        let mut named = vec![];
399        for item in destructuring.items() {
400            match item {
401                ast::DestructuringItem::Pattern(pat) => {
402                    pos.push(self.check_let_pattern(pat).unwrap_or(Ty::Any));
403                }
404                ast::DestructuringItem::Named(named_item) => {
405                    let key = Interned::new_str(named_item.name().get().as_str());
406                    let ty = self
407                        .check_let_pattern(named_item.pattern())
408                        .unwrap_or(Ty::Any);
409                    named.push((key, ty));
410                }
411                // `rest` in `(a, b, ..rest) = t` is be ignored because we perform checking for
412                // `t``.
413                ast::DestructuringItem::Spread(..) => {}
414            }
415        }
416
417        let tuple_ty = (!pos.is_empty()).then(|| Ty::Tuple(pos.into()));
418        let dict_ty = (!named.is_empty()).then(|| Ty::Dict(RecordTy::new(named)));
419        Ty::union(tuple_ty, dict_ty)
420    }
421
422    fn check_element_of<T>(&mut self, ty: &Ty, pol: bool, context: &LinkedNode, checker: &mut T)
423    where
424        T: PostSigChecker,
425    {
426        let mut checker = PostSigCheckWorker(self, checker);
427        ty.sig_surface(pol, sig_context_of(context), &mut checker)
428    }
429}
430
431trait PostSigChecker {
432    fn check(
433        &mut self,
434        checker: &mut PostTypeChecker,
435        sig: Sig,
436        args: &[Interned<ArgsTy>],
437        pol: bool,
438    ) -> Option<()>;
439}
440
441impl<T> PostSigChecker for T
442where
443    T: FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()>,
444{
445    fn check(
446        &mut self,
447        checker: &mut PostTypeChecker,
448        sig: Sig,
449        args: &[Interned<ArgsTy>],
450        pol: bool,
451    ) -> Option<()> {
452        self(checker, sig, args, pol)
453    }
454}
455
456#[derive(BindTyCtx)]
457#[bind(0)]
458struct PostSigCheckWorker<'x, 'a, T>(&'x mut PostTypeChecker<'a>, &'x mut T);
459
460impl<T: PostSigChecker> SigChecker for PostSigCheckWorker<'_, '_, T> {
461    fn check(
462        &mut self,
463        sig: Sig,
464        args: &mut crate::analysis::SigCheckContext,
465        pol: bool,
466    ) -> Option<()> {
467        self.1.check(self.0, sig, &args.args, pol)
468    }
469}
470
471fn sig_context_of(context: &LinkedNode) -> SigSurfaceKind {
472    match context.kind() {
473        SyntaxKind::Parenthesized => SigSurfaceKind::ArrayOrDict,
474        SyntaxKind::Array => {
475            let arr = context.cast::<ast::Array>();
476            if arr.is_some_and(|arr| arr.items().next().is_some()) {
477                SigSurfaceKind::Array
478            } else {
479                SigSurfaceKind::ArrayOrDict
480            }
481        }
482        SyntaxKind::Dict => SigSurfaceKind::Dict,
483        _ => SigSurfaceKind::Array,
484    }
485}