Skip to content

Commit

Permalink
Add basic mutex synchronization to GcsFileSystem and GoogleAuthProvider.
Browse files Browse the repository at this point in the history
Change: 132623564
  • Loading branch information
rinugun authored and tensorflower-gardener committed Sep 9, 2016
1 parent 5665a31 commit 10272d2
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
4 changes: 3 additions & 1 deletion tensorflow/core/platform/cloud/auth_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class AuthProvider {
public:
virtual ~AuthProvider() {}

/// Returns the short-term authentication bearer token.
/// \brief Returns the short-term authentication bearer token.
///
/// Safe for concurrent use by multiple threads.
virtual Status GetToken(string* t) = 0;

static Status GetToken(AuthProvider* provider, string* token) {
Expand Down
18 changes: 11 additions & 7 deletions tensorflow/core/platform/cloud/gcs_file_system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/cloud/gcs_file_system.h"
#include <stdio.h>
#include <unistd.h>
#include <algorithm>
Expand All @@ -27,12 +28,13 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
#include "tensorflow/core/platform/cloud/time_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

Expand Down Expand Up @@ -95,9 +97,10 @@ class GcsRandomAccessFile : public RandomAccessFile {
http_request_factory_(http_request_factory),
read_ahead_bytes_(read_ahead_bytes) {}

/// The implementation of reads with a read-ahead buffer.
/// The implementation of reads with a read-ahead buffer. Thread-safe.
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
mutex_lock lock(mu_);
const bool range_start_included = offset >= buffer_start_offset_;
const bool range_end_included =
offset + n <= buffer_start_offset_ + buffer_content_size_;
Expand Down Expand Up @@ -169,12 +172,13 @@ class GcsRandomAccessFile : public RandomAccessFile {

// The buffer-related members need to be mutable, because they are modified
// by the const Read() method.
mutable std::unique_ptr<char[]> buffer_;
mutable size_t buffer_size_ = 0;
mutable mutex mu_;
mutable std::unique_ptr<char[]> buffer_ GUARDED_BY(mu_);
mutable size_t buffer_size_ GUARDED_BY(mu_) = 0;
// The original file offset of the first byte in the buffer.
mutable size_t buffer_start_offset_ = 0;
mutable size_t buffer_content_size_ = 0;
mutable bool buffer_reached_eof_ = false;
mutable size_t buffer_start_offset_ GUARDED_BY(mu_) = 0;
mutable size_t buffer_content_size_ GUARDED_BY(mu_) = 0;
mutable bool buffer_reached_eof_ GUARDED_BY(mu_) = false;
};

/// \brief GCS-based implementation of a writeable file.
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/platform/cloud/google_auth_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ GoogleAuthProvider::GoogleAuthProvider(
env_(env) {}

Status GoogleAuthProvider::GetToken(string* t) {
mutex_lock lock(mu_);
const uint64 now_sec = env_->NowSeconds();

if (!current_token_.empty() &&
Expand Down
11 changes: 8 additions & 3 deletions tensorflow/core/platform/cloud/google_auth_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/platform/cloud/auth_provider.h"
#include "tensorflow/core/platform/cloud/oauth_client.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

Expand All @@ -31,7 +33,9 @@ class GoogleAuthProvider : public AuthProvider {
std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env);
virtual ~GoogleAuthProvider() {}

/// Returns the short-term authentication bearer token.
/// \brief Returns the short-term authentication bearer token.
///
/// Safe for concurrent use by multiple threads.
Status GetToken(string* token) override;

private:
Expand All @@ -47,8 +51,9 @@ class GoogleAuthProvider : public AuthProvider {
std::unique_ptr<OAuthClient> oauth_client_;
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
Env* env_;
string current_token_;
uint64 expiration_timestamp_sec_ = 0;
mutex mu_;
string current_token_ GUARDED_BY(mu_);
uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
TF_DISALLOW_COPY_AND_ASSIGN(GoogleAuthProvider);
};

Expand Down

0 comments on commit 10272d2

Please sign in to comment.