@ -138,9 +138,11 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
# NOTE(review): interior of DeepEPAll2AllManagerBase.__init__ — the `def`
# line is outside this view; only comments are added here.
super().__init__(cpu_group)

# Cache of previously-built all2all handles, keyed by construction kwargs
# (see get_handle in subclasses) — presumably avoids re-allocating DeepEP
# buffers; confirm against the Cache implementation.
self.handle_cache = Cache()

# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 20
# Use all SMs for all2all communication
# This will need to be adjusted for dual-batch overlap
# NOTE(review): the assignment below unconditionally overwrites
# `self.num_sms = 20` above — the earlier assignment (and its comment)
# is a dead store, likely residue of a flattened diff where the
# hard-coded default was replaced by the device SM count; consider
# removing the dead lines upstream.
device = self.dp_group.device
props = torch.cuda.get_device_properties(device)
self.num_sms = props.multi_processor_count
|
||||
|
||||
def get_handle(self, kwargs):
    """Return an all2all communication handle built from *kwargs*.

    Abstract in this base class: concrete DeepEP manager subclasses are
    expected to override this (presumably consulting ``self.handle_cache``
    populated in ``__init__`` — confirm against subclass implementations).

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user