tinymist_query/analysis/
post_tyck.rs

1//! Infer more than the principal type of some expression.
2
3use std::collections::HashSet;
4use tinymist_derive::BindTyCtx;
5
6use super::{
7    ArgsTy, Sig, SigChecker, SigShape, SigSurfaceKind, SigTy, Ty, TyCtx, TyCtxMut, TypeBounds,
8    TypeInfo, TypeVar,
9};
10use super::{DynTypeBounds, ParamAttrs, ParamTy, SharedContext, prelude::*};
11use crate::syntax::{ArgClass, SyntaxContext, VarClass, classify_context, classify_context_outer};
12use crate::ty::BuiltinTy;
13
14/// With given type information, check the type of a literal expression again by
15/// touching the possible related nodes.
16#[typst_macros::time(span = node.span())]
17pub(crate) fn post_type_check(
18    ctx: Arc<SharedContext>,
19    ti: &TypeInfo,
20    node: LinkedNode,
21) -> Option<Ty> {
22    let mut checker = PostTypeChecker::new(ctx, ti);
23    let res = checker.check(&node);
24    checker.simplify(&res?)
25}
26
27#[derive(Default)]
28struct SignatureReceiver {
29    lbs_dedup: HashSet<Ty>,
30    ubs_dedup: HashSet<Ty>,
31    bounds: TypeBounds,
32}
33
34impl SignatureReceiver {
35    fn insert(&mut self, ty: Ty, pol: bool) {
36        crate::log_debug_ct!("post check receive: {ty:?}");
37        if !pol {
38            if self.lbs_dedup.insert(ty.clone()) {
39                self.bounds.lbs.push(ty);
40            }
41        } else if self.ubs_dedup.insert(ty.clone()) {
42            self.bounds.ubs.push(ty);
43        }
44    }
45
46    fn finalize(self) -> Ty {
47        Ty::Let(self.bounds.into())
48    }
49}
50
51fn check_signature<'a>(
52    receiver: &'a mut SignatureReceiver,
53    arg: &'a ArgClass,
54) -> impl FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()> + 'a {
55    move |worker, sig, args, pol| {
56        let (sig, _is_partialize) = match sig {
57            Sig::Partialize(sig) => (*sig, true),
58            sig => (sig, false),
59        };
60
61        let SigShape { sig: sig_ins, .. } = sig.shape(worker)?;
62
63        match &arg {
64            ArgClass::Named(n) => {
65                let ident = n.cast::<ast::Ident>()?;
66                let ty = sig_ins.named(&ident.into())?;
67                receiver.insert(ty.clone(), !pol);
68
69                Some(())
70            }
71            ArgClass::Positional {
72                // todo: spreads
73                spreads: _,
74                positional,
75                is_spread,
76            } => {
77                if *is_spread {
78                    return None;
79                }
80
81                // truncate args
82                let bound_pos = args
83                    .iter()
84                    .map(|args| args.positional_params().len())
85                    .sum::<usize>();
86                if let Some(nth) = sig_ins.pos_or_rest(bound_pos + positional) {
87                    receiver.insert(nth, !pol);
88                }
89
90                // names
91                for (name, ty) in sig_ins.named_params() {
92                    let field = ParamTy::new(ty.clone(), name.clone(), ParamAttrs::named());
93                    receiver.insert(Ty::Param(field), !pol);
94                }
95
96                Some(())
97            }
98        }
99    }
100}
101
102pub(crate) struct PostTypeChecker<'a> {
103    ctx: Arc<SharedContext>,
104    pub info: &'a TypeInfo,
105    checked: HashMap<Span, Option<Ty>>,
106    locals: TypeInfo,
107}
108
109impl TyCtx for PostTypeChecker<'_> {
110    fn global_bounds(&self, var: &Interned<TypeVar>, pol: bool) -> Option<DynTypeBounds> {
111        self.info.global_bounds(var, pol)
112    }
113
114    fn local_bind_of(&self, var: &Interned<TypeVar>) -> Option<Ty> {
115        self.locals.local_bind_of(var)
116    }
117}
118
119impl TyCtxMut for PostTypeChecker<'_> {
120    type Snap = <TypeInfo as TyCtxMut>::Snap;
121
122    fn start_scope(&mut self) -> Self::Snap {
123        self.locals.start_scope()
124    }
125
126    fn end_scope(&mut self, snap: Self::Snap) {
127        self.locals.end_scope(snap)
128    }
129
130    fn bind_local(&mut self, var: &Interned<TypeVar>, ty: Ty) {
131        self.locals.bind_local(var, ty);
132    }
133
134    fn type_of_func(&mut self, func: &Func) -> Option<Interned<SigTy>> {
135        Some(self.ctx.type_of_func(func.clone()).type_sig())
136    }
137
138    fn type_of_value(&mut self, val: &Value) -> Ty {
139        self.ctx.type_of_value(val)
140    }
141
142    fn check_module_item(&mut self, _module: TypstFileId, _key: &StrRef) -> Option<Ty> {
143        None
144    }
145}
146
147impl<'a> PostTypeChecker<'a> {
148    pub fn new(ctx: Arc<SharedContext>, info: &'a TypeInfo) -> Self {
149        Self {
150            ctx,
151            info,
152            checked: HashMap::new(),
153            locals: TypeInfo::default(),
154        }
155    }
156
157    fn check(&mut self, node: &LinkedNode) -> Option<Ty> {
158        let span = node.span();
159        if let Some(ty) = self.checked.get(&span) {
160            return ty.clone();
161        }
162        // loop detection
163        self.checked.insert(span, None);
164
165        let ty = self.check_(node);
166        self.checked.insert(span, ty.clone());
167        ty
168    }
169
170    fn simplify(&mut self, ty: &Ty) -> Option<Ty> {
171        Some(self.info.simplify(ty.clone(), false))
172    }
173
174    fn check_(&mut self, node: &LinkedNode) -> Option<Ty> {
175        let context = node.parent()?;
176        crate::log_debug_ct!("post check: {:?}::{:?}", context.kind(), node.kind());
177
178        let context_ty = self.check_context(context, node);
179        let self_ty = if !matches!(node.kind(), SyntaxKind::Label | SyntaxKind::Ref) {
180            self.info.type_of_span(node.span())
181        } else {
182            None
183        };
184
185        let can_penetrate_context = !(matches!(
186            node.kind(),
187            SyntaxKind::Hash | SyntaxKind::ContentBlock | SyntaxKind::CodeBlock
188        ) || matches!(context.kind(), SyntaxKind::FieldAccess) && {
189            let field_access = context.cast::<ast::FieldAccess>()?;
190            field_access.field().span() == node.span()
191        });
192
193        let contextual_self_ty = can_penetrate_context
194            .then(|| self.check_cursor(classify_context(node.clone(), None), context_ty));
195        crate::log_debug_ct!(
196            "post check(res): {:?}::{:?} -> {self_ty:?}, {contextual_self_ty:?}",
197            context.kind(),
198            node.kind(),
199        );
200
201        Ty::union(self_ty, contextual_self_ty.flatten())
202    }
203
204    fn check_or(&mut self, node: &LinkedNode, ty: Option<Ty>) -> Option<Ty> {
205        Ty::union(self.check(node), ty)
206    }
207
208    fn check_cursor(
209        &mut self,
210        cursor: Option<SyntaxContext>,
211        context_ty: Option<Ty>,
212    ) -> Option<Ty> {
213        let Some(cursor) = cursor else {
214            return context_ty;
215        };
216        crate::log_debug_ct!("post check target: {cursor:?}");
217
218        match &cursor {
219            SyntaxContext::Arg {
220                callee,
221                args: _,
222                target,
223                is_set,
224            } => {
225                let callee_ty = self.check_or(callee, context_ty)?;
226                crate::log_debug_ct!(
227                    "post check call target: ({callee_ty:?})::{target:?} is_set: {is_set}"
228                );
229
230                let sig = self.ctx.sig_of_type_or_dyn(self.info, callee_ty, callee)?;
231                crate::log_debug_ct!("post check call sig: {target:?} {sig:?}");
232                let mut resp = SignatureReceiver::default();
233
234                match target {
235                    ArgClass::Named(n) => {
236                        let ident = n.cast::<ast::Ident>()?.into();
237                        let ty = sig.primary().get_named(&ident)?;
238                        // todo: losing docs
239                        resp.insert(ty.ty.clone(), false);
240                    }
241                    ArgClass::Positional {
242                        // todo: spreads
243                        spreads: _,
244                        positional,
245                        is_spread,
246                    } => {
247                        if *is_spread {
248                            return None;
249                        }
250
251                        // truncate args
252                        let shift = sig.param_shift();
253                        let nth = sig
254                            .primary()
255                            .get_pos(shift + positional)
256                            .or_else(|| sig.primary().rest());
257                        if let Some(nth) = nth {
258                            resp.insert(Ty::Param(nth.clone()), false);
259                        }
260
261                        // names
262                        for field in sig.primary().named() {
263                            if *is_set && !field.attrs.settable {
264                                continue;
265                            }
266
267                            resp.insert(Ty::Param(field.clone()), false);
268                        }
269                    }
270                }
271
272                crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
273                Some(resp.finalize())
274            }
275            SyntaxContext::Element { container, target } => {
276                let container_ty = self.check_or(container, context_ty)?;
277                crate::log_debug_ct!("post check element target: ({container_ty:?})::{target:?}");
278
279                let mut resp = SignatureReceiver::default();
280
281                self.check_element_of(
282                    &container_ty,
283                    false,
284                    container,
285                    &mut check_signature(&mut resp, target),
286                );
287
288                crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
289                Some(resp.finalize())
290            }
291            SyntaxContext::Paren {
292                container,
293                is_before,
294            } => {
295                let container_ty = self.check_or(container, context_ty)?;
296                crate::log_debug_ct!("post check paren target: {container_ty:?}::{is_before:?}");
297
298                let mut resp = SignatureReceiver::default();
299                // todo: this is legal, but it makes it sometimes complete itself.
300                // e.g. completing `""` on `let x = ("|")`
301                resp.bounds.lbs.push(container_ty.clone());
302
303                let target = ArgClass::first_positional();
304                self.check_element_of(
305                    &container_ty,
306                    false,
307                    container,
308                    &mut check_signature(&mut resp, &target),
309                );
310
311                crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
312                Some(resp.finalize())
313            }
314            SyntaxContext::ImportPath(..) | SyntaxContext::IncludePath(..) => {
315                Some(Ty::Builtin(BuiltinTy::Path(crate::ty::PathKind::Source {
316                    allow_package: true,
317                })))
318            }
319            SyntaxContext::VarAccess(VarClass::Ident(node))
320            | SyntaxContext::VarAccess(VarClass::FieldAccess(node))
321            | SyntaxContext::VarAccess(VarClass::DotAccess(node))
322            | SyntaxContext::Label { node, .. }
323            | SyntaxContext::Ref { node, .. }
324            | SyntaxContext::Normal(node) => {
325                let label_or_ref_ty = match cursor {
326                    SyntaxContext::Label { is_error: true, .. } => {
327                        Some(Ty::Builtin(BuiltinTy::Label))
328                    }
329                    SyntaxContext::Ref {
330                        suffix_colon: true, ..
331                    } => Some(Ty::Builtin(BuiltinTy::RefLabel)),
332                    _ => None,
333                };
334                let ty = self.check_or(node, context_ty);
335                crate::log_debug_ct!("post check target normal: {ty:?} {label_or_ref_ty:?}");
336                ty.or(label_or_ref_ty)
337            }
338        }
339    }
340
341    fn check_context(&mut self, context: &LinkedNode, node: &LinkedNode) -> Option<Ty> {
342        match context.kind() {
343            SyntaxKind::LetBinding => {
344                let let_binding = context.cast::<ast::LetBinding>()?;
345                let let_init = let_binding.init()?;
346                if let_init.span() != node.span() {
347                    return None;
348                }
349
350                match let_binding.kind() {
351                    ast::LetBindingKind::Closure(_c) => None,
352                    ast::LetBindingKind::Normal(pattern) => {
353                        self.destruct_let(pattern, node.clone())
354                    }
355                }
356            }
357            SyntaxKind::Args => self.check_cursor(
358                // todo: not well behaved
359                classify_context_outer(context.clone(), node.clone()),
360                None,
361            ),
362            // todo: constraint node
363            SyntaxKind::Named => self.check_cursor(classify_context(context.clone(), None), None),
364            _ => None,
365        }
366    }
367
368    fn destruct_let(&mut self, pattern: ast::Pattern, node: LinkedNode) -> Option<Ty> {
369        match pattern {
370            ast::Pattern::Placeholder(_) => None,
371            ast::Pattern::Normal(n) => {
372                let ast::Expr::Ident(ident) = n else {
373                    return None;
374                };
375                self.info.type_of_span(ident.span())
376            }
377            ast::Pattern::Parenthesized(paren_expr) => {
378                self.destruct_let(paren_expr.expr().to_untyped().cast()?, node)
379            }
380            // todo: pattern matching
381            ast::Pattern::Destructuring(_d) => {
382                let _ = node;
383                None
384            }
385        }
386    }
387
388    fn check_element_of<T>(&mut self, ty: &Ty, pol: bool, context: &LinkedNode, checker: &mut T)
389    where
390        T: PostSigChecker,
391    {
392        let mut checker = PostSigCheckWorker(self, checker);
393        ty.sig_surface(pol, sig_context_of(context), &mut checker)
394    }
395}
396
397trait PostSigChecker {
398    fn check(
399        &mut self,
400        checker: &mut PostTypeChecker,
401        sig: Sig,
402        args: &[Interned<ArgsTy>],
403        pol: bool,
404    ) -> Option<()>;
405}
406
407impl<T> PostSigChecker for T
408where
409    T: FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()>,
410{
411    fn check(
412        &mut self,
413        checker: &mut PostTypeChecker,
414        sig: Sig,
415        args: &[Interned<ArgsTy>],
416        pol: bool,
417    ) -> Option<()> {
418        self(checker, sig, args, pol)
419    }
420}
421
422#[derive(BindTyCtx)]
423#[bind(0)]
424struct PostSigCheckWorker<'x, 'a, T>(&'x mut PostTypeChecker<'a>, &'x mut T);
425
426impl<T: PostSigChecker> SigChecker for PostSigCheckWorker<'_, '_, T> {
427    fn check(
428        &mut self,
429        sig: Sig,
430        args: &mut crate::analysis::SigCheckContext,
431        pol: bool,
432    ) -> Option<()> {
433        self.1.check(self.0, sig, &args.args, pol)
434    }
435}
436
437fn sig_context_of(context: &LinkedNode) -> SigSurfaceKind {
438    match context.kind() {
439        SyntaxKind::Parenthesized => SigSurfaceKind::ArrayOrDict,
440        SyntaxKind::Array => {
441            let arr = context.cast::<ast::Array>();
442            if arr.is_some_and(|arr| arr.items().next().is_some()) {
443                SigSurfaceKind::Array
444            } else {
445                SigSurfaceKind::ArrayOrDict
446            }
447        }
448        SyntaxKind::Dict => SigSurfaceKind::Dict,
449        _ => SigSurfaceKind::Array,
450    }
451}