diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index ee4180d8..d9e93f53 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -97,7 +97,8 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): hidden_states, input_ids, attention_masks """ text_encoder = models[0] - assert isinstance(text_encoder, Gemma2Model) + # Check model or torch dynamo OptimizedModule + assert isinstance(text_encoder, Gemma2Model) or isinstance(text_encoder._orig_mod, Gemma2Model), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}" input_ids, attention_masks = tokens outputs = text_encoder(