keynes42 committed
Commit e38adab · verified · 1 Parent(s): fbd7ae6

Update app.py


Add flash attention

Files changed (1)
app.py (+1 -0)
app.py CHANGED
@@ -75,6 +75,7 @@ class BasicModel:
     model_id,
     torch_dtype=torch.float16,
     device_map="auto", ## auto-distributes to GPU
+    attn_implementation="flash_attention_2",
     token=hf_token,
     trust_remote_code=True, ## <- Use the custom code that isn't part of the base transformers library yet
     quantization_config=quantization_config ## <- Load 4-bit quantization because vRAM is not big enough
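For context, a minimal, self-contained sketch of what the full loading call looks like after this change. The model_id value, the token, and the BitsAndBytesConfig settings below are assumptions for illustration; the actual values live elsewhere in app.py and are not shown in this diff.

# Sketch of the loading call with flash attention enabled (assumed setup values).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "some-org/some-model"   # hypothetical; app.py's real model_id is not shown in the diff
hf_token = "hf_..."                # placeholder for the real Hugging Face access token

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # 4-bit quantization because vRAM is not big enough
    bnb_4bit_compute_dtype=torch.float16,  # assumed compute dtype; the real config is not shown
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",                        # auto-distributes to GPU
    attn_implementation="flash_attention_2",  # the line added by this commit
    token=hf_token,
    trust_remote_code=True,                   # use custom code not yet in base transformers
    quantization_config=quantization_config,  # load in 4-bit
)

Note that flash_attention_2 only works if the flash-attn package is installed and the GPU supports it (Ampere-class or newer); otherwise loading will fail and the attn_implementation argument should be dropped or set back to the default.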