-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: gemma2 samples with accelerated TPU and GPU (#4395)
- Loading branch information
Stepan Rasputny
authored
Oct 7, 2024
1 parent
fc150e3
commit 8957a58
Showing
6 changed files
with
521 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package snippets | ||
|
||
// [START generativeaionvertexai_gemma2_predict_gpu] | ||
import ( | ||
"context" | ||
"fmt" | ||
"io" | ||
|
||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb" | ||
|
||
"google.golang.org/protobuf/types/known/structpb" | ||
) | ||
|
||
// predictGPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators. | ||
func predictGPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error { | ||
ctx := context.Background() | ||
|
||
// Note: client can be initialized in the following way: | ||
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) | ||
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint)) | ||
// if err != nil { | ||
// return fmt.Errorf("unable to create prediction client: %v", err) | ||
// } | ||
// defer client.Close() | ||
|
||
gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID) | ||
prompt := "Why is the sky blue?" | ||
parameters := map[string]interface{}{ | ||
"temperature": 0.9, | ||
"maxOutputTokens": 1024, | ||
"topP": 1.0, | ||
"topK": 1, | ||
} | ||
|
||
// Encapsulate the prompt in a correct format for TPUs. | ||
// Pay attention that prompt should be set in "inputs" field. | ||
// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}] | ||
promptValue, err := structpb.NewValue(map[string]interface{}{ | ||
"inputs": prompt, | ||
"parameters": parameters, | ||
}) | ||
if err != nil { | ||
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err) | ||
return err | ||
} | ||
|
||
req := &aiplatformpb.PredictRequest{ | ||
Endpoint: gemma2Endpoint, | ||
Instances: []*structpb.Value{promptValue}, | ||
} | ||
|
||
resp, err := client.Predict(ctx, req) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
prediction := resp.GetPredictions() | ||
value := prediction[0].GetStringValue() | ||
fmt.Fprintf(w, "%v", value) | ||
|
||
return nil | ||
} | ||
|
||
// [END generativeaionvertexai_gemma2_predict_gpu] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package snippets | ||
|
||
// [START generativeaionvertexai_gemma2_predict_tpu] | ||
import ( | ||
"context" | ||
"fmt" | ||
"io" | ||
|
||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb" | ||
|
||
"google.golang.org/protobuf/types/known/structpb" | ||
) | ||
|
||
// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators. | ||
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error { | ||
ctx := context.Background() | ||
|
||
// Note: client can be initialized in the following way: | ||
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) | ||
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint)) | ||
// if err != nil { | ||
// return fmt.Errorf("unable to create prediction client: %v", err) | ||
// } | ||
// defer client.Close() | ||
|
||
gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID) | ||
prompt := "Why is the sky blue?" | ||
parameters := map[string]interface{}{ | ||
"temperature": 0.9, | ||
"maxOutputTokens": 1024, | ||
"topP": 1.0, | ||
"topK": 1, | ||
} | ||
|
||
// Encapsulate the prompt in a correct format for TPUs. | ||
// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}] | ||
promptValue, err := structpb.NewValue(map[string]interface{}{ | ||
"prompt": prompt, | ||
"parameters": parameters, | ||
}) | ||
if err != nil { | ||
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err) | ||
return err | ||
} | ||
|
||
req := &aiplatformpb.PredictRequest{ | ||
Endpoint: gemma2Endpoint, | ||
Instances: []*structpb.Value{promptValue}, | ||
} | ||
|
||
resp, err := client.Predict(ctx, req) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
prediction := resp.GetPredictions() | ||
value := prediction[0].GetStringValue() | ||
fmt.Fprintf(w, "%v", value) | ||
|
||
return nil | ||
} | ||
|
||
// [END generativeaionvertexai_gemma2_predict_tpu] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package snippets | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"strings" | ||
"testing" | ||
|
||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb" | ||
"github.com/GoogleCloudPlatform/golang-samples/internal/testutil" | ||
"github.com/googleapis/gax-go/v2" | ||
) | ||
|
||
func TestPredictGemma2(t *testing.T) { | ||
tc := testutil.SystemTest(t) | ||
|
||
projectID := tc.ProjectID | ||
var buf bytes.Buffer | ||
client := PredictionsClient{} | ||
|
||
t.Run("GPU predict", func(t *testing.T) { | ||
buf.Reset() | ||
// Mock ID used to check if GPU was called | ||
if err := predictGPU(&buf, client, projectID, GPUEndpointRegion, GPUEndpointID); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") { | ||
t.Error("generated text content not found in response") | ||
} | ||
}) | ||
|
||
t.Run("TPU predict", func(t *testing.T) { | ||
buf.Reset() | ||
// Mock ID used to check if TPU was called | ||
if err := predictTPU(&buf, client, projectID, TPUEndpointRegion, TPUEndpointID); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") { | ||
t.Error("generated text content not found in response") | ||
} | ||
}) | ||
} | ||
|
||
type PredictClientInterface interface { | ||
Close() error | ||
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
module github.com/GoogleCloudPlatform/golang-samples/gemma2 | ||
|
||
go 1.21 | ||
|
||
require ( | ||
cloud.google.com/go/aiplatform v1.68.0 | ||
github.com/GoogleCloudPlatform/golang-samples v0.0.0-20240918200157-a00ca430a14b | ||
github.com/googleapis/gax-go/v2 v2.13.0 | ||
google.golang.org/protobuf v1.34.2 | ||
) | ||
|
||
require ( | ||
cloud.google.com/go v0.115.1 // indirect | ||
cloud.google.com/go/auth v0.9.3 // indirect | ||
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect | ||
cloud.google.com/go/compute/metadata v0.5.0 // indirect | ||
cloud.google.com/go/iam v1.2.0 // indirect | ||
cloud.google.com/go/longrunning v0.6.0 // indirect | ||
cloud.google.com/go/storage v1.43.0 // indirect | ||
github.com/felixge/httpsnoop v1.0.4 // indirect | ||
github.com/go-logr/logr v1.4.2 // indirect | ||
github.com/go-logr/stdr v1.2.2 // indirect | ||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect | ||
github.com/google/s2a-go v0.1.8 // indirect | ||
github.com/google/uuid v1.6.0 // indirect | ||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect | ||
go.opencensus.io v0.24.0 // indirect | ||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect | ||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect | ||
go.opentelemetry.io/otel v1.29.0 // indirect | ||
go.opentelemetry.io/otel/metric v1.29.0 // indirect | ||
go.opentelemetry.io/otel/trace v1.29.0 // indirect | ||
golang.org/x/crypto v0.27.0 // indirect | ||
golang.org/x/net v0.29.0 // indirect | ||
golang.org/x/oauth2 v0.23.0 // indirect | ||
golang.org/x/sync v0.8.0 // indirect | ||
golang.org/x/sys v0.25.0 // indirect | ||
golang.org/x/text v0.18.0 // indirect | ||
golang.org/x/time v0.6.0 // indirect | ||
google.golang.org/api v0.197.0 // indirect | ||
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect | ||
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect | ||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect | ||
google.golang.org/grpc v1.66.1 // indirect | ||
) |
Oops, something went wrong.