1use std::collections::HashSet;
4use tinymist_derive::BindTyCtx;
5
6use super::{
7 ArgsTy, Sig, SigChecker, SigShape, SigSurfaceKind, SigTy, Ty, TyCtx, TyCtxMut, TypeBounds,
8 TypeInfo, TypeVar,
9};
10use super::{DynTypeBounds, ParamAttrs, ParamTy, SharedContext, prelude::*};
11use crate::syntax::{ArgClass, SyntaxContext, VarClass, classify_context, classify_context_outer};
12use crate::ty::BuiltinTy;
13
14#[typst_macros::time(span = node.span())]
17pub(crate) fn post_type_check(
18 ctx: Arc<SharedContext>,
19 ti: &TypeInfo,
20 node: LinkedNode,
21) -> Option<Ty> {
22 let mut checker = PostTypeChecker::new(ctx, ti);
23 let res = checker.check(&node);
24 checker.simplify(&res?)
25}
26
27#[derive(Default)]
28struct SignatureReceiver {
29 lbs_dedup: HashSet<Ty>,
30 ubs_dedup: HashSet<Ty>,
31 bounds: TypeBounds,
32}
33
34impl SignatureReceiver {
35 fn insert(&mut self, ty: Ty, pol: bool) {
36 crate::log_debug_ct!("post check receive: {ty:?}");
37 if !pol {
38 if self.lbs_dedup.insert(ty.clone()) {
39 self.bounds.lbs.push(ty);
40 }
41 } else if self.ubs_dedup.insert(ty.clone()) {
42 self.bounds.ubs.push(ty);
43 }
44 }
45
46 fn finalize(self) -> Ty {
47 Ty::Let(self.bounds.into())
48 }
49}
50
51fn check_signature<'a>(
52 receiver: &'a mut SignatureReceiver,
53 arg: &'a ArgClass,
54) -> impl FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()> + 'a {
55 move |worker, sig, args, pol| {
56 let (sig, _is_partialize) = match sig {
57 Sig::Partialize(sig) => (*sig, true),
58 sig => (sig, false),
59 };
60
61 let SigShape { sig: sig_ins, .. } = sig.shape(worker)?;
62
63 match &arg {
64 ArgClass::Named(n) => {
65 let ident = n.cast::<ast::Ident>()?;
66 let ty = sig_ins.named(&ident.into())?;
67 receiver.insert(ty.clone(), !pol);
68
69 Some(())
70 }
71 ArgClass::Positional {
72 spreads: _,
74 positional,
75 is_spread,
76 } => {
77 if *is_spread {
78 return None;
79 }
80
81 let bound_pos = args
83 .iter()
84 .map(|args| args.positional_params().len())
85 .sum::<usize>();
86 if let Some(nth) = sig_ins.pos_or_rest(bound_pos + positional) {
87 receiver.insert(nth, !pol);
88 }
89
90 for (name, ty) in sig_ins.named_params() {
92 let field = ParamTy::new(ty.clone(), name.clone(), ParamAttrs::named());
93 receiver.insert(Ty::Param(field), !pol);
94 }
95
96 Some(())
97 }
98 }
99 }
100}
101
102pub(crate) struct PostTypeChecker<'a> {
103 ctx: Arc<SharedContext>,
104 pub info: &'a TypeInfo,
105 checked: HashMap<Span, Option<Ty>>,
106 locals: TypeInfo,
107}
108
109impl TyCtx for PostTypeChecker<'_> {
110 fn global_bounds(&self, var: &Interned<TypeVar>, pol: bool) -> Option<DynTypeBounds> {
111 self.info.global_bounds(var, pol)
112 }
113
114 fn local_bind_of(&self, var: &Interned<TypeVar>) -> Option<Ty> {
115 self.locals.local_bind_of(var)
116 }
117}
118
119impl TyCtxMut for PostTypeChecker<'_> {
120 type Snap = <TypeInfo as TyCtxMut>::Snap;
121
122 fn start_scope(&mut self) -> Self::Snap {
123 self.locals.start_scope()
124 }
125
126 fn end_scope(&mut self, snap: Self::Snap) {
127 self.locals.end_scope(snap)
128 }
129
130 fn bind_local(&mut self, var: &Interned<TypeVar>, ty: Ty) {
131 self.locals.bind_local(var, ty);
132 }
133
134 fn type_of_func(&mut self, func: &Func) -> Option<Interned<SigTy>> {
135 Some(self.ctx.type_of_func(func.clone()).type_sig())
136 }
137
138 fn type_of_value(&mut self, val: &Value) -> Ty {
139 self.ctx.type_of_value(val)
140 }
141
142 fn check_module_item(&mut self, _module: TypstFileId, _key: &StrRef) -> Option<Ty> {
143 None
144 }
145}
146
147impl<'a> PostTypeChecker<'a> {
148 pub fn new(ctx: Arc<SharedContext>, info: &'a TypeInfo) -> Self {
149 Self {
150 ctx,
151 info,
152 checked: HashMap::new(),
153 locals: TypeInfo::default(),
154 }
155 }
156
157 fn check(&mut self, node: &LinkedNode) -> Option<Ty> {
158 let span = node.span();
159 if let Some(ty) = self.checked.get(&span) {
160 return ty.clone();
161 }
162 self.checked.insert(span, None);
164
165 let ty = self.check_(node);
166 self.checked.insert(span, ty.clone());
167 ty
168 }
169
170 fn simplify(&mut self, ty: &Ty) -> Option<Ty> {
171 Some(self.info.simplify(ty.clone(), false))
172 }
173
174 fn check_(&mut self, node: &LinkedNode) -> Option<Ty> {
175 let context = node.parent()?;
176 crate::log_debug_ct!("post check: {:?}::{:?}", context.kind(), node.kind());
177
178 let context_ty = self.check_context(context, node);
179 let self_ty = if !matches!(node.kind(), SyntaxKind::Label | SyntaxKind::Ref) {
180 self.info.type_of_span(node.span())
181 } else {
182 None
183 };
184
185 let can_penetrate_context = !(matches!(
186 node.kind(),
187 SyntaxKind::Hash | SyntaxKind::ContentBlock | SyntaxKind::CodeBlock
188 ) || matches!(context.kind(), SyntaxKind::FieldAccess) && {
189 let field_access = context.cast::<ast::FieldAccess>()?;
190 field_access.field().span() == node.span()
191 });
192
193 let contextual_self_ty = can_penetrate_context
194 .then(|| self.check_cursor(classify_context(node.clone(), None), context_ty));
195 crate::log_debug_ct!(
196 "post check(res): {:?}::{:?} -> {self_ty:?}, {contextual_self_ty:?}",
197 context.kind(),
198 node.kind(),
199 );
200
201 Ty::union(self_ty, contextual_self_ty.flatten())
202 }
203
204 fn check_or(&mut self, node: &LinkedNode, ty: Option<Ty>) -> Option<Ty> {
205 Ty::union(self.check(node), ty)
206 }
207
208 fn check_cursor(
209 &mut self,
210 cursor: Option<SyntaxContext>,
211 context_ty: Option<Ty>,
212 ) -> Option<Ty> {
213 let Some(cursor) = cursor else {
214 return context_ty;
215 };
216 crate::log_debug_ct!("post check target: {cursor:?}");
217
218 match &cursor {
219 SyntaxContext::Arg {
220 callee,
221 args: _,
222 target,
223 is_set,
224 } => {
225 let callee_ty = self.check_or(callee, context_ty)?;
226 crate::log_debug_ct!(
227 "post check call target: ({callee_ty:?})::{target:?} is_set: {is_set}"
228 );
229
230 let sig = self.ctx.sig_of_type_or_dyn(self.info, callee_ty, callee)?;
231 crate::log_debug_ct!("post check call sig: {target:?} {sig:?}");
232 let mut resp = SignatureReceiver::default();
233
234 match target {
235 ArgClass::Named(n) => {
236 let ident = n.cast::<ast::Ident>()?.into();
237 let ty = sig.primary().get_named(&ident)?;
238 resp.insert(ty.ty.clone(), false);
240 }
241 ArgClass::Positional {
242 spreads: _,
244 positional,
245 is_spread,
246 } => {
247 if *is_spread {
248 return None;
249 }
250
251 let shift = sig.param_shift();
253 let nth = sig
254 .primary()
255 .get_pos(shift + positional)
256 .or_else(|| sig.primary().rest());
257 if let Some(nth) = nth {
258 resp.insert(Ty::Param(nth.clone()), false);
259 }
260
261 for field in sig.primary().named() {
263 if *is_set && !field.attrs.settable {
264 continue;
265 }
266
267 resp.insert(Ty::Param(field.clone()), false);
268 }
269 }
270 }
271
272 crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
273 Some(resp.finalize())
274 }
275 SyntaxContext::Element { container, target } => {
276 let container_ty = self.check_or(container, context_ty)?;
277 crate::log_debug_ct!("post check element target: ({container_ty:?})::{target:?}");
278
279 let mut resp = SignatureReceiver::default();
280
281 self.check_element_of(
282 &container_ty,
283 false,
284 container,
285 &mut check_signature(&mut resp, target),
286 );
287
288 crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
289 Some(resp.finalize())
290 }
291 SyntaxContext::Paren {
292 container,
293 is_before,
294 } => {
295 let container_ty = self.check_or(container, context_ty)?;
296 crate::log_debug_ct!("post check paren target: {container_ty:?}::{is_before:?}");
297
298 let mut resp = SignatureReceiver::default();
299 resp.bounds.lbs.push(container_ty.clone());
302
303 let target = ArgClass::first_positional();
304 self.check_element_of(
305 &container_ty,
306 false,
307 container,
308 &mut check_signature(&mut resp, &target),
309 );
310
311 crate::log_debug_ct!("post check target iterated: {:?}", resp.bounds);
312 Some(resp.finalize())
313 }
314 SyntaxContext::ImportPath(..) | SyntaxContext::IncludePath(..) => {
315 Some(Ty::Builtin(BuiltinTy::Path(crate::ty::PathKind::Source {
316 allow_package: true,
317 })))
318 }
319 SyntaxContext::VarAccess(VarClass::Ident(node))
320 | SyntaxContext::VarAccess(VarClass::FieldAccess(node))
321 | SyntaxContext::VarAccess(VarClass::DotAccess(node))
322 | SyntaxContext::Label { node, .. }
323 | SyntaxContext::Ref { node, .. }
324 | SyntaxContext::Normal(node) => {
325 let label_or_ref_ty = match cursor {
326 SyntaxContext::Label { is_error: true, .. } => {
327 Some(Ty::Builtin(BuiltinTy::Label))
328 }
329 SyntaxContext::Ref {
330 suffix_colon: true, ..
331 } => Some(Ty::Builtin(BuiltinTy::RefLabel)),
332 _ => None,
333 };
334 let ty = self.check_or(node, context_ty);
335 crate::log_debug_ct!("post check target normal: {ty:?} {label_or_ref_ty:?}");
336 ty.or(label_or_ref_ty)
337 }
338 }
339 }
340
341 fn check_context(&mut self, context: &LinkedNode, node: &LinkedNode) -> Option<Ty> {
342 match context.kind() {
343 SyntaxKind::LetBinding => {
344 let let_binding = context.cast::<ast::LetBinding>()?;
345 let let_init = let_binding.init()?;
346 if let_init.span() != node.span() {
347 return None;
348 }
349
350 match let_binding.kind() {
351 ast::LetBindingKind::Closure(_c) => None,
352 ast::LetBindingKind::Normal(pattern) => {
353 self.destruct_let(pattern, node.clone())
354 }
355 }
356 }
357 SyntaxKind::Args => self.check_cursor(
358 classify_context_outer(context.clone(), node.clone()),
360 None,
361 ),
362 SyntaxKind::Named => self.check_cursor(classify_context(context.clone(), None), None),
364 _ => None,
365 }
366 }
367
368 fn destruct_let(&mut self, pattern: ast::Pattern, node: LinkedNode) -> Option<Ty> {
369 match pattern {
370 ast::Pattern::Placeholder(_) => None,
371 ast::Pattern::Normal(n) => {
372 let ast::Expr::Ident(ident) = n else {
373 return None;
374 };
375 self.info.type_of_span(ident.span())
376 }
377 ast::Pattern::Parenthesized(paren_expr) => {
378 self.destruct_let(paren_expr.expr().to_untyped().cast()?, node)
379 }
380 ast::Pattern::Destructuring(_d) => {
382 let _ = node;
383 None
384 }
385 }
386 }
387
388 fn check_element_of<T>(&mut self, ty: &Ty, pol: bool, context: &LinkedNode, checker: &mut T)
389 where
390 T: PostSigChecker,
391 {
392 let mut checker = PostSigCheckWorker(self, checker);
393 ty.sig_surface(pol, sig_context_of(context), &mut checker)
394 }
395}
396
397trait PostSigChecker {
398 fn check(
399 &mut self,
400 checker: &mut PostTypeChecker,
401 sig: Sig,
402 args: &[Interned<ArgsTy>],
403 pol: bool,
404 ) -> Option<()>;
405}
406
407impl<T> PostSigChecker for T
408where
409 T: FnMut(&mut PostTypeChecker, Sig, &[Interned<ArgsTy>], bool) -> Option<()>,
410{
411 fn check(
412 &mut self,
413 checker: &mut PostTypeChecker,
414 sig: Sig,
415 args: &[Interned<ArgsTy>],
416 pol: bool,
417 ) -> Option<()> {
418 self(checker, sig, args, pol)
419 }
420}
421
/// Adapts a [`PostSigChecker`] callback (field 1) so it can be driven as a
/// [`SigChecker`], together with the post-type checker (field 0).
///
/// `#[bind(0)]` points the `BindTyCtx` derive at field 0. Presumably it
/// forwards the type-context queries to the wrapped `PostTypeChecker`; confirm
/// in `tinymist_derive`.
#[derive(BindTyCtx)]
#[bind(0)]
struct PostSigCheckWorker<'x, 'a, T>(&'x mut PostTypeChecker<'a>, &'x mut T);
425
426impl<T: PostSigChecker> SigChecker for PostSigCheckWorker<'_, '_, T> {
427 fn check(
428 &mut self,
429 sig: Sig,
430 args: &mut crate::analysis::SigCheckContext,
431 pol: bool,
432 ) -> Option<()> {
433 self.1.check(self.0, sig, &args.args, pol)
434 }
435}
436
437fn sig_context_of(context: &LinkedNode) -> SigSurfaceKind {
438 match context.kind() {
439 SyntaxKind::Parenthesized => SigSurfaceKind::ArrayOrDict,
440 SyntaxKind::Array => {
441 let arr = context.cast::<ast::Array>();
442 if arr.is_some_and(|arr| arr.items().next().is_some()) {
443 SigSurfaceKind::Array
444 } else {
445 SigSurfaceKind::ArrayOrDict
446 }
447 }
448 SyntaxKind::Dict => SigSurfaceKind::Dict,
449 _ => SigSurfaceKind::Array,
450 }
451}