Skip to content

Commit

Permalink
小优化
Browse files (browse the repository at this point in the history)
  • Loading branch information
黄宇扬 committed Jan 20, 2025
1 parent ec1a067 commit 03ce4f7
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/fastllm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -1167,14 +1167,15 @@ namespace fastllm {
} else if (device == DataDevice::CUDA) {
int sourceDevice = this->dataDeviceIds.size() == 0 ? 0 : this->dataDeviceIds[0];
int destDevice = deviceIds.size() == 0 ? 0 : deviceIds[0];
FastllmCudaSetDevice(destDevice);
void *newCudaData = FastllmCudaMalloc(expansionBytes);

FastllmCudaMemcpyBetweenDevices(destDevice, newCudaData, sourceDevice, this->cudaData, expansionBytes);
FastllmCudaSetDevice(sourceDevice);
FastllmCudaFree(this->cudaData);
this->cudaData = newCudaData;
FastllmCudaSetDevice(destDevice);
if (sourceDevice != destDevice) {
FastllmCudaSetDevice(destDevice);
void *newCudaData = FastllmCudaMalloc(expansionBytes);
FastllmCudaMemcpyBetweenDevices(destDevice, newCudaData, sourceDevice, this->cudaData, expansionBytes);
FastllmCudaSetDevice(sourceDevice);
FastllmCudaFree(this->cudaData);
this->cudaData = newCudaData;
FastllmCudaSetDevice(destDevice);
}
}
}
#endif
Expand Down

0 comments on commit 03ce4f7

Please sign in to comment.